diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu new file mode 100644 index 00000000000..57730ca0182 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "multinomial_impl.cuh" + +template +__global__ void NormInput(T *input, const size_t distributions, const size_t categories) { + size_t size = distributions * categories; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if ((pos + 1) % categories != 0) { + int de_pos = (1 + pos / categories) * categories - 1; + input[pos] /= input[de_pos]; + } + } + return; +} + +template +__global__ void CheckZeroKernel(const size_t distributions, const size_t categories, const T *input, T *out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (distributions); pos += blockDim.x * gridDim.x) { + if (input[(1 + pos) * categories - 1] <= 0) { + out[0] = 1; + } + } + return; +} + +template +void CheckZero(const size_t distributions, const size_t categories, const T *input, T *output, + cudaStream_t cuda_stream) { + CheckZeroKernel<<>>(distributions, categories, input, output); +} + +template +__global__ void CheckNonNegKernel(const size_t size, const T *input, T *out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (input[pos] < 0) { + out[0] = 1; + } + } + return; +} + +template +void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t cuda_stream) { + CheckNonNegKernel<<>>(size, input, output); +} + +template +__device__ int BinarySearchForMultinomial(T *start_addr, int size, T rand) { + int start = 0; + int end = size; + while (end - start > 0) { + int mid = start + (end - start) / 2; + T mid_val = start_addr[mid]; + if (mid_val < rand) { + start = mid + 1; + } else { + end = mid; + } + } + if (start == size) { + start = size - 1; + } + return start; +} + +template +__global__ void MultinomialKernel(int seed, T *input, int num_sample, curandState *globalState, int *output, + size_t distributions, size_t categories) { + int count = num_sample * distributions; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) { + int j = i / num_sample % distributions; + curand_init(seed, i, 0, &globalState[i]); + auto rand = curand_uniform(&globalState[i]); + int pick = BinarySearchForMultinomial(input + j * categories, categories, rand); + output[i] = pick; + } + return; +} + +template +void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions, + size_t categories, cudaStream_t cuda_stream) { + int RNG_seed = 0; + if (seed != 0) { + RNG_seed = seed; + } else { + RNG_seed = time(NULL); + } + int count = distributions * num_sample; + int count1 = distributions * categories; + NormInput<<>>(input, distributions, categories); + MultinomialKernel<<>>(RNG_seed, input, num_sample, globalState, + output, distributions, categories); + return; +} + +template void Multinomial(int seed, float *input, int num_sample, curandState *globalState, int *output, + size_t distributions, size_t categories, cudaStream_t cuda_stream); +template void CheckNonNeg(const size_t size, const float *input, float *output, cudaStream_t cuda_stream); +template void CheckZero(const size_t distributions, const size_t categories, const float *input, float *output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh new file mode 100644 index 00000000000..097f8ef9004 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_IMPL_CUH_ +#include +#include "runtime/device/gpu/cuda_common.h" + +template +void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions, + size_t categories, cudaStream_t cuda_stream); +template +void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream); +template +void CheckZero(const size_t distributions, const size_t categories, const T *input, T *output, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.cc new file mode 100644 index 00000000000..32c5bd39ca8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Multinomial, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + MultinomialGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h new file mode 100644 index 00000000000..e5d01b4ac6a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h @@ -0,0 +1,141 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class MultinomialGpuKernel : public GpuKernel { + public: + MultinomialGpuKernel() : input_size_0_(0), output_size_(0), distributions_(0), workspace_size_(sizeof(curandState)) {} + ~MultinomialGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + void *workspace_addr = GetDeviceAddress(workspace, 0); + curandState *devStates = reinterpret_cast(workspace_addr); + int *output_addr = GetDeviceAddress(outputs, 0); + T *input_addr = GetDeviceAddress(inputs, 0); + int categories = SizeToInt(inputs[0]->size / sizeof(T)) / distributions_; + int num_sample = SizeToInt(outputs[0]->size / sizeof(T)) / distributions_; + // check input + T *flag = nullptr; + T *cflag = nullptr; + CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast(&cflag), sizeof(T)), "cudaMalloc failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&flag, sizeof(T)), "cudaMallocHost failed."); + CalFloatStatus(input_size_0_ / sizeof(T), input_addr, cflag, reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed."); + if (*flag > 0) { + MS_LOG(EXCEPTION) << "Input is invalid (containing NaN, -inf or inf)"; + } + CheckNonNeg(input_size_0_ / sizeof(T), input_addr, cflag, reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed."); + if (*flag > 0) { + MS_LOG(EXCEPTION) << "Input is invalid (input element < 0)"; + } + T *cum_sum_input = nullptr; + CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast(&cum_sum_input), input_size_0_), + "cudaMalloc failed."); + CumSum(input_addr, cum_sum_input, cum_sum_input, IntToSize(distributions_), IntToSize(categories), 1, + IntToSize(categories), 1, false, false, reinterpret_cast(stream_ptr)); + CheckZero(IntToSize(distributions_), IntToSize(categories), cum_sum_input, cflag, + reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed."); + if (*flag > 0) { + MS_LOG(EXCEPTION) << "Input is invalid (sum <= 0)"; + } + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), + "cudaStreamSynchronize failed."); + Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_), + IntToSize(categories), reinterpret_cast(stream_ptr)); + + CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cum_sum_input), "cudaFree failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cflag), "cudaFree failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(flag), "cudaFreeHost failed."); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but multinomial needs 2 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but multinomial needs 1 output."; + return false; + } + auto input_shape_0 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape_0.size() == 1) { + distributions_ = 1; + } else { + distributions_ = input_shape_0[0]; + } + input_size_0_ = sizeof(T); + for (size_t i = 0; i < input_shape_0.size(); i++) { + input_size_0_ *= input_shape_0[i]; + } + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + output_size_ = sizeof(int); + for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ *= output_shape[i]; + workspace_size_ *= output_shape[i]; + } + seed_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_0_); + input_size_list_.push_back(sizeof(int)); + output_size_list_.push_back(output_size_); + workspace_size_list_.push_back(workspace_size_); + } + + private: + size_t input_size_0_; + size_t output_size_; + size_t distributions_; + size_t workspace_size_; + int seed_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_GPU_KERNEL_H_ diff --git a/tests/st/ops/gpu/test_multinomial_op.py b/tests/st/ops/gpu/test_multinomial_op.py new file mode 100644 index 00000000000..95f71336a00 --- /dev/null +++ b/tests/st/ops/gpu/test_multinomial_op.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +from mindspore.ops import composite as C +import mindspore.context as context +from mindspore import Tensor + +context.set_context(device_target='GPU') + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_multinomial(): + x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32)) + x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32)) + out0 = C.multinomial(x0, 1, True) + out1 = C.multinomial(x0, 2, True) + out2 = C.multinomial(x1, 6, True) + out3 = C.multinomial(x0, 1, False) + out4 = C.multinomial(x0, 2, False) + assert out0.asnumpy().shape == (1,) + assert out1.asnumpy().shape == (2,) + assert out2.asnumpy().shape == (2, 6) + assert out3.asnumpy().shape == (1,) + assert out4.asnumpy().shape == (2,)