forked from mindspore-Ecosystem/mindspore
[feat][assistant][I5EWHM] Modify Multinomial GPU from DeprecatedNativeGpuKernelMod to NativeGpuKernelMod
This commit is contained in:
parent
378b787bf1
commit
e46bfc0905
|
@ -14,7 +14,7 @@ mindspore.ops.Multinomial
|
|||
- **dtype** (dtype) - 输出数据类型,必须是int32或者int64,默认类型:int32。
|
||||
|
||||
输入:
|
||||
- **x** (Tensor) - 包含累加概率和的输入Tensor,必须是一维或二维。具有float16、float32或float64数据类型。CPU和GPU后端支持一维或者二维,Ascend后端仅支持二维。
|
||||
- **x** (Tensor) - 包含累加概率和的输入Tensor,必须是1维或2维。CPU和GPU后端支持1维或者2维,Ascend后端仅支持2维。
|
||||
- **num_samples** (int) - 要抽取的样本数。
|
||||
|
||||
输出:
|
||||
|
@ -22,7 +22,6 @@ mindspore.ops.Multinomial
|
|||
|
||||
异常:
|
||||
- **TypeError** - 如果 `seed` 或者 `seed2` 不是int类型。
|
||||
- **TypeError** - 如果 `x` 不是数据类型为float16、float32或者float64的Tensor。
|
||||
- **TypeError** - 如果 `num_samples` 不是int类型。
|
||||
- **TypeError** - 如果 `dtype` 不是int32或者int64类型。
|
||||
- **ValueError** - 如果 `seed` 或者 `seed2` 小于零。
|
||||
|
|
|
@ -31,6 +31,20 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
// clang-format off
|
||||
#define ADD_KERNEL(prob_dtype, prob_type) \
|
||||
{KernelAttr().AddInputAttr(kNumberType##prob_dtype).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), \
|
||||
&MultinomialCpuKernelMod::LaunchKernel<prob_type, int32_t>}, \
|
||||
{KernelAttr().AddInputAttr(kNumberType##prob_dtype).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), \
|
||||
&MultinomialCpuKernelMod::LaunchKernel<prob_type, int64_t>}, \
|
||||
{KernelAttr().AddInputAttr(kNumberType##prob_dtype).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), \
|
||||
&MultinomialCpuKernelMod::LaunchKernel<prob_type, int32_t>}, \
|
||||
{KernelAttr().AddInputAttr(kNumberType##prob_dtype).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), \
|
||||
&MultinomialCpuKernelMod::LaunchKernel<prob_type, int64_t>}
|
||||
// clang-format on
|
||||
} // namespace
|
||||
|
||||
bool MultinomialCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->GetPrim()->name();
|
||||
|
@ -45,7 +59,6 @@ bool MultinomialCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
|
|||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
input_shape_ = inputs[0]->GetShapeVector();
|
||||
output_dtype_ = outputs[0]->GetDtype();
|
||||
input0_dtype_ = inputs[0]->GetDtype();
|
||||
input1_dtype_ = inputs[1]->GetDtype();
|
||||
|
@ -74,15 +87,9 @@ int MultinomialCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
|
|||
return ret;
|
||||
}
|
||||
|
||||
input_shape_ = inputs[0]->GetShapeVector();
|
||||
int64_t elem_num = std::accumulate(input_shape_.begin(), input_shape_.end(), 1, std::multiplies<size_t>());
|
||||
|
||||
if (input0_dtype_ == kNumberTypeFloat16) {
|
||||
(void)workspace_size_list_.emplace_back(elem_num * sizeof(float16));
|
||||
} else if (input0_dtype_ == kNumberTypeFloat32) {
|
||||
(void)workspace_size_list_.emplace_back(elem_num * sizeof(float));
|
||||
} else if (input0_dtype_ == kNumberTypeFloat64) {
|
||||
(void)workspace_size_list_.emplace_back(elem_num * sizeof(double));
|
||||
}
|
||||
(void)workspace_size_list_.emplace_back(elem_num * sizeof(TypeIdToType(input0_dtype_)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -169,28 +176,9 @@ bool MultinomialCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
|
|||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, MultinomialCpuKernelMod::MultinomialFunc>> MultinomialCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialCpuKernelMod::LaunchKernel<float, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialCpuKernelMod::LaunchKernel<float16, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialCpuKernelMod::LaunchKernel<double, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialCpuKernelMod::LaunchKernel<float, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialCpuKernelMod::LaunchKernel<float16, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialCpuKernelMod::LaunchKernel<float, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialCpuKernelMod::LaunchKernel<float16, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialCpuKernelMod::LaunchKernel<double, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialCpuKernelMod::LaunchKernel<float, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialCpuKernelMod::LaunchKernel<float16, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialCpuKernelMod::LaunchKernel<double, int64_t>}};
|
||||
ADD_KERNEL(Float16, float16), ADD_KERNEL(Float32, float), ADD_KERNEL(Float64, double), ADD_KERNEL(Int8, int8_t),
|
||||
ADD_KERNEL(Int16, int16_t), ADD_KERNEL(Int32, int32_t), ADD_KERNEL(Int64, int64_t), ADD_KERNEL(UInt8, uint8_t),
|
||||
ADD_KERNEL(UInt16, uint16_t), ADD_KERNEL(UInt32, uint32_t), ADD_KERNEL(UInt64, uint64_t)};
|
||||
|
||||
std::vector<KernelAttr> MultinomialCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
|
|
|
@ -89,8 +89,8 @@ __device__ int BinarySearchForMultinomial(T *start_addr, int size, T rand) {
|
|||
return start;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void MultinomialKernel(int row, int col, T *probs, curandState *state, int64_t *num_sample, int *output) {
|
||||
template <typename T, typename S>
|
||||
__global__ void MultinomialKernel(int row, int col, T *probs, curandState *state, int64_t *num_sample, S *output) {
|
||||
// Load the probs to shared memory.
|
||||
extern __shared__ float accum_probs[];
|
||||
int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
@ -103,7 +103,7 @@ __global__ void MultinomialKernel(int row, int col, T *probs, curandState *state
|
|||
accum_probs[shm_base_index] = probs[probs_base_index];
|
||||
for (int i = 1; i < col; i++) {
|
||||
probs_base_index++;
|
||||
accum_probs[shm_base_index + i] = accum_probs[shm_base_index + i - 1] + probs[probs_base_index];
|
||||
accum_probs[shm_base_index + i] = accum_probs[shm_base_index + i - 1] + static_cast<float>(probs[probs_base_index]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
@ -119,14 +119,13 @@ __global__ void MultinomialKernel(int row, int col, T *probs, curandState *state
|
|||
auto local_state = state[gid];
|
||||
for (int i = 0; i < num_sample[0]; i++) {
|
||||
float rand = curand_uniform(&local_state);
|
||||
output[output_base_index + i] = BinarySearchForMultinomial(&accum_probs[shm_base_index], col, rand);
|
||||
output[output_base_index + i] = static_cast<S>(BinarySearchForMultinomial(&accum_probs[shm_base_index], col, rand));
|
||||
}
|
||||
state[gid] = local_state;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Multinomial(int row, int col, T *probs, curandState *state, int64_t *num_sample, int *output,
|
||||
cudaStream_t stream) {
|
||||
template <typename T, typename S>
|
||||
void Multinomial(int row, int col, T *probs, curandState *state, int64_t *num_sample, S *output, cudaStream_t stream) {
|
||||
// Every block processes several rows; how many depends on shared memory usage.
|
||||
constexpr int max_shm_used_per_block = 256;
|
||||
int block_dim = std::max(Floor(std::min(row, max_shm_used_per_block), col), 1);
|
||||
|
@ -136,8 +135,51 @@ void Multinomial(int row, int col, T *probs, curandState *state, int64_t *num_sa
|
|||
MultinomialKernel<<<grid_dim, block_dim, shm_size, stream>>>(row, col, probs, state, num_sample, output);
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void Multinomial<float>(int row, int col, float *probs, curandState *state,
|
||||
int64_t *num_sample, int *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<float, int64_t>(int row, int col, float *probs, curandState *state,
|
||||
int64_t *num_sample, int64_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<double, int64_t>(int row, int col, double *probs, curandState *state,
|
||||
int64_t *num_sample, int64_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<half, int64_t>(int row, int col, half *probs, curandState *state,
|
||||
int64_t *num_sample, int64_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<int8_t, int64_t>(int row, int col, int8_t *probs, curandState *state,
|
||||
int64_t *num_sample, int64_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<int16_t, int64_t>(int row, int col, int16_t *probs, curandState *state,
|
||||
int64_t *num_sample, int64_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<int32_t, int64_t>(int row, int col, int32_t *probs, curandState *state,
|
||||
int64_t *num_sample, int64_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<int64_t, int64_t>(int row, int col, int64_t *probs, curandState *state,
|
||||
int64_t *num_sample, int64_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<uint8_t, int64_t>(int row, int col, uint8_t *probs, curandState *state,
|
||||
int64_t *num_sample, int64_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<uint16_t, int64_t>(int row, int col, uint16_t *probs, curandState *state,
|
||||
int64_t *num_sample, int64_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<uint32_t, int64_t>(int row, int col, uint32_t *probs, curandState *state,
|
||||
int64_t *num_sample, int64_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<uint64_t, int64_t>(int row, int col, uint64_t *probs, curandState *state,
|
||||
int64_t *num_sample, int64_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<float, int32_t>(int row, int col, float *probs, curandState *state,
|
||||
int64_t *num_sample, int32_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<double, int32_t>(int row, int col, double *probs, curandState *state,
|
||||
int64_t *num_sample, int32_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<half, int32_t>(int row, int col, half *probs, curandState *state,
|
||||
int64_t *num_sample, int32_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<int8_t, int32_t>(int row, int col, int8_t *probs, curandState *state,
|
||||
int64_t *num_sample, int32_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<int16_t, int32_t>(int row, int col, int16_t *probs, curandState *state,
|
||||
int64_t *num_sample, int32_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<int32_t, int32_t>(int row, int col, int32_t *probs, curandState *state,
|
||||
int64_t *num_sample, int32_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<int64_t, int32_t>(int row, int col, int64_t *probs, curandState *state,
|
||||
int64_t *num_sample, int32_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<uint8_t, int32_t>(int row, int col, uint8_t *probs, curandState *state,
|
||||
int64_t *num_sample, int32_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<uint16_t, int32_t>(int row, int col, uint16_t *probs, curandState *state,
|
||||
int64_t *num_sample, int32_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<uint32_t, int32_t>(int row, int col, uint32_t *probs, curandState *state,
|
||||
int64_t *num_sample, int32_t *output, cudaStream_t stream);
|
||||
template CUDA_LIB_EXPORT void Multinomial<uint64_t, int32_t>(int row, int col, uint64_t *probs, curandState *state,
|
||||
int64_t *num_sample, int32_t *output, cudaStream_t stream);
|
||||
|
||||
template CUDA_LIB_EXPORT void CheckNonNeg<float>(const size_t size, const float *input, float *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CheckZero<float>(const size_t distributions, const size_t categories, const float *input,
|
||||
|
|
|
@ -20,8 +20,8 @@
|
|||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
|
||||
|
||||
CUDA_LIB_EXPORT void InitRandState(int seed, int num, curandState *state, cudaStream_t stream);
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void Multinomial(int row, int col, T *probs, curandState *rand_state, int64_t *num_sample, int *output,
|
||||
template <typename T, typename S>
|
||||
CUDA_LIB_EXPORT void Multinomial(int row, int col, T *probs, curandState *rand_state, int64_t *num_sample, S *output,
|
||||
cudaStream_t stream);
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream);
|
||||
|
|
|
@ -74,11 +74,11 @@ bool MultinomialGpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inpu
|
|||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename S>
|
||||
void MultinomialGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs, void *stream_ptr) {
|
||||
int *output_addr = GetDeviceAddress<int>(outputs, 0);
|
||||
T *probs_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
S *output_addr = GetDeviceAddress<S>(outputs, 0);
|
||||
int64_t *num_sample_addr = GetDeviceAddress<int64_t>(inputs, 1);
|
||||
if (distributions_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', divide by zero. the distributions_ is 0.";
|
||||
|
@ -103,8 +103,50 @@ void MultinomialGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
|
|||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, MultinomialGpuKernelMod::LaunchFunc>> MultinomialGpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<half, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<float>}};
|
||||
&MultinomialGpuKernelMod::LaunchKernel<float, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<double, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<int8_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<int16_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<int32_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<int64_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<uint8_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<uint16_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<uint32_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<uint64_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<half, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<float, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<double, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<int8_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<int16_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<int32_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<uint8_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&MultinomialGpuKernelMod::LaunchKernel<uint64_t, int64_t>}};
|
||||
|
||||
std::vector<KernelAttr> MultinomialGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
|
|
|
@ -31,7 +31,6 @@
|
|||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/multinomial_impl.cuh"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cumsum_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -59,7 +58,7 @@ class MultinomialGpuKernelMod : public NativeGpuKernelMod {
|
|||
bool rand_state_init_{false};
|
||||
curandState *rand_state_{nullptr};
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename S>
|
||||
void LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs,
|
||||
void *stream_ptr);
|
||||
using LaunchFunc = std::function<void(MultinomialGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
|
|
|
@ -111,7 +111,8 @@ TypePtr MultinomialInferType(const PrimitivePtr &prim, const std::vector<Abstrac
|
|||
auto prim_name = prim->name();
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
auto num_samples_type = input_args[1]->BuildType();
|
||||
const std::set valid_types_1 = {kFloat16, kFloat32, kFloat64};
|
||||
const std::set valid_types_1 = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32,
|
||||
kInt64, kUInt8, kUInt16, kUInt32, kUInt64};
|
||||
const std::set valid_types_2 = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types_1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTypeValid("num_samples", num_samples_type, valid_types_2, prim_name);
|
||||
|
@ -165,6 +166,8 @@ class MIND_API AGMultinomialInfer : public abstract::OpInferBase {
|
|||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
return MultinomialInfer(engine, primitive, input_args);
|
||||
}
|
||||
|
||||
// Returns the set of input indices this infer implementation reports as value-dependent: only index 1.
// NOTE(review): index 1 is presumably `num_samples` (MultinomialInferType reads it from input_args[1]) — confirm.
std::set<int64_t> GetValueDependArgIndices() const override {
  std::set<int64_t> value_depend_indices;
  (void)value_depend_indices.insert(1);
  return value_depend_indices;
}
|
||||
};
|
||||
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(Multinomial, prim::kPrimMultinomial, AGMultinomialInfer, false);
|
||||
|
|
|
@ -789,8 +789,7 @@ class Multinomial(Primitive):
|
|||
|
||||
Inputs:
|
||||
- **x** (Tensor) - the input tensor containing the cumsum of probabilities, must be 1 or 2
|
||||
dimensions. Must be one of the following types: float16, float32, float64. CPU and GPU
|
||||
supports x 1 or 2 dimensions and Ascend only supports 2 dimensions.
|
||||
dimensions. CPU and GPU support 1 or 2 dimensions for `x`, while Ascend supports only 2 dimensions.
|
||||
- **num_samples** (int) - number of samples to draw, must be a nonnegative number.
|
||||
|
||||
Outputs:
|
||||
|
@ -798,7 +797,6 @@ class Multinomial(Primitive):
|
|||
|
||||
Raises:
|
||||
TypeError: If `seed` or `seed2` is not an int.
|
||||
TypeError: If `x` is not a Tensor whose dtype is float16, float32, float64.
|
||||
TypeError: If dtype of `num_samples` is not int.
|
||||
TypeError: If `dtype` is not int32 or int64.
|
||||
ValueError: If `seed` or `seed2` is less than 0.
|
||||
|
|
Loading…
Reference in New Issue