forked from mindspore-Ecosystem/mindspore
add multinomial backend
This commit is contained in:
parent
b331e62400
commit
40748a30c7
|
@ -0,0 +1,118 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "multinomial_impl.cuh"
|
||||
|
||||
template <typename T>
|
||||
__global__ void NormInput(T *input, const size_t distributions, const size_t categories) {
|
||||
size_t size = distributions * categories;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
if ((pos + 1) % categories != 0) {
|
||||
int de_pos = (1 + pos / categories) * categories - 1;
|
||||
input[pos] /= input[de_pos];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void CheckZeroKernel(const size_t distributions, const size_t categories, const T *input, T *out) {
|
||||
out[0] = 0;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (distributions); pos += blockDim.x * gridDim.x) {
|
||||
if (input[(1 + pos) * categories - 1] <= 0) {
|
||||
out[0] = 1;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CheckZero(const size_t distributions, const size_t categories, const T *input, T *output,
|
||||
cudaStream_t cuda_stream) {
|
||||
CheckZeroKernel<<<GET_BLOCKS(distributions), GET_THREADS, 0, cuda_stream>>>(distributions, categories, input, output);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void CheckNonNegKernel(const size_t size, const T *input, T *out) {
|
||||
out[0] = 0;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
if (input[pos] < 0) {
|
||||
out[0] = 1;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t cuda_stream) {
|
||||
CheckNonNegKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ int BinarySearchForMultinomial(T *start_addr, int size, T rand) {
|
||||
int start = 0;
|
||||
int end = size;
|
||||
while (end - start > 0) {
|
||||
int mid = start + (end - start) / 2;
|
||||
T mid_val = start_addr[mid];
|
||||
if (mid_val < rand) {
|
||||
start = mid + 1;
|
||||
} else {
|
||||
end = mid;
|
||||
}
|
||||
}
|
||||
if (start == size) {
|
||||
start = size - 1;
|
||||
}
|
||||
return start;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void MultinomialKernel(int seed, T *input, int num_sample, curandState *globalState, int *output,
|
||||
size_t distributions, size_t categories) {
|
||||
int count = num_sample * distributions;
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
|
||||
int j = i / num_sample % distributions;
|
||||
curand_init(seed, i, 0, &globalState[i]);
|
||||
auto rand = curand_uniform(&globalState[i]);
|
||||
int pick = BinarySearchForMultinomial(input + j * categories, categories, rand);
|
||||
output[i] = pick;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions,
|
||||
size_t categories, cudaStream_t cuda_stream) {
|
||||
int RNG_seed = 0;
|
||||
if (seed != 0) {
|
||||
RNG_seed = seed;
|
||||
} else {
|
||||
RNG_seed = time(NULL);
|
||||
}
|
||||
int count = distributions * num_sample;
|
||||
int count1 = distributions * categories;
|
||||
NormInput<<<GET_BLOCKS(count1), GET_THREADS, 0, cuda_stream>>>(input, distributions, categories);
|
||||
MultinomialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(RNG_seed, input, num_sample, globalState,
|
||||
output, distributions, categories);
|
||||
return;
|
||||
}
|
||||
|
||||
template void Multinomial<float>(int seed, float *input, int num_sample, curandState *globalState, int *output,
|
||||
size_t distributions, size_t categories, cudaStream_t cuda_stream);
|
||||
template void CheckNonNeg<float>(const size_t size, const float *input, float *output, cudaStream_t cuda_stream);
|
||||
template void CheckZero<float>(const size_t distributions, const size_t categories, const float *input, float *output,
|
||||
cudaStream_t cuda_stream);
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_IMPL_CUH_
|
||||
#include <curand_kernel.h>
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions,
|
||||
size_t categories, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream);
|
||||
template <typename T>
|
||||
void CheckZero(const size_t distributions, const size_t categories, const T *input, T *output, cudaStream_t stream);
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_IMPL_CUH_
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Multinomial,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
MultinomialGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,141 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_GPU_KERNEL_H_
|
||||
|
||||
#include <curand_kernel.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/cumsum_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class MultinomialGpuKernel : public GpuKernel {
|
||||
public:
|
||||
MultinomialGpuKernel() : input_size_0_(0), output_size_(0), distributions_(0), workspace_size_(sizeof(curandState)) {}
|
||||
~MultinomialGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
void *workspace_addr = GetDeviceAddress<void *>(workspace, 0);
|
||||
curandState *devStates = reinterpret_cast<curandState *>(workspace_addr);
|
||||
int *output_addr = GetDeviceAddress<int>(outputs, 0);
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
int categories = SizeToInt(inputs[0]->size / sizeof(T)) / distributions_;
|
||||
int num_sample = SizeToInt(outputs[0]->size / sizeof(T)) / distributions_;
|
||||
// check input
|
||||
T *flag = nullptr;
|
||||
T *cflag = nullptr;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cflag), sizeof(T)), "cudaMalloc failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&flag, sizeof(T)), "cudaMallocHost failed.");
|
||||
CalFloatStatus(input_size_0_ / sizeof(T), input_addr, cflag, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed.");
|
||||
if (*flag > 0) {
|
||||
MS_LOG(EXCEPTION) << "Input is invalid (containing NaN, -inf or inf)";
|
||||
}
|
||||
CheckNonNeg(input_size_0_ / sizeof(T), input_addr, cflag, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed.");
|
||||
if (*flag > 0) {
|
||||
MS_LOG(EXCEPTION) << "Input is invalid (input element < 0)";
|
||||
}
|
||||
T *cum_sum_input = nullptr;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cum_sum_input), input_size_0_),
|
||||
"cudaMalloc failed.");
|
||||
CumSum(input_addr, cum_sum_input, cum_sum_input, IntToSize(distributions_), IntToSize(categories), 1,
|
||||
IntToSize(categories), 1, false, false, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CheckZero(IntToSize(distributions_), IntToSize(categories), cum_sum_input, cflag,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(flag, cflag, sizeof(T), cudaMemcpyDeviceToHost), "cudaMemcpyAsync failed.");
|
||||
if (*flag > 0) {
|
||||
MS_LOG(EXCEPTION) << "Input is invalid (sum <= 0)";
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaStreamSynchronize failed.");
|
||||
Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_),
|
||||
IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cum_sum_input), "cudaFree failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cflag), "cudaFree failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(flag), "cudaFreeHost failed.");
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but multinomial needs 2 input.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but multinomial needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
auto input_shape_0 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (input_shape_0.size() == 1) {
|
||||
distributions_ = 1;
|
||||
} else {
|
||||
distributions_ = input_shape_0[0];
|
||||
}
|
||||
input_size_0_ = sizeof(T);
|
||||
for (size_t i = 0; i < input_shape_0.size(); i++) {
|
||||
input_size_0_ *= input_shape_0[i];
|
||||
}
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
output_size_ = sizeof(int);
|
||||
for (size_t i = 0; i < output_shape.size(); i++) {
|
||||
output_size_ *= output_shape[i];
|
||||
workspace_size_ *= output_shape[i];
|
||||
}
|
||||
seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed"));
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_0_);
|
||||
input_size_list_.push_back(sizeof(int));
|
||||
output_size_list_.push_back(output_size_);
|
||||
workspace_size_list_.push_back(workspace_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t input_size_0_;
|
||||
size_t output_size_;
|
||||
size_t distributions_;
|
||||
size_t workspace_size_;
|
||||
int seed_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_GPU_KERNEL_H_
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mindspore.ops import composite as C
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
|
||||
context.set_context(device_target='GPU')
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_multinomial():
|
||||
x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32))
|
||||
x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32))
|
||||
out0 = C.multinomial(x0, 1, True)
|
||||
out1 = C.multinomial(x0, 2, True)
|
||||
out2 = C.multinomial(x1, 6, True)
|
||||
out3 = C.multinomial(x0, 1, False)
|
||||
out4 = C.multinomial(x0, 2, False)
|
||||
assert out0.asnumpy().shape == (1,)
|
||||
assert out1.asnumpy().shape == (2,)
|
||||
assert out2.asnumpy().shape == (2, 6)
|
||||
assert out3.asnumpy().shape == (1,)
|
||||
assert out4.asnumpy().shape == (2,)
|
Loading…
Reference in New Issue