forked from mindspore-Ecosystem/mindspore
commit
9297ba0a8d
|
@ -33,7 +33,12 @@ namespace kernel {
|
|||
template <typename T>
|
||||
class MultinomialGpuKernel : public GpuKernel {
|
||||
public:
|
||||
MultinomialGpuKernel() : input_size_0_(0), output_size_(0), distributions_(0), workspace_size_(sizeof(curandState)) {}
|
||||
MultinomialGpuKernel()
|
||||
: input_size_0_(0),
|
||||
output_size_(0),
|
||||
distributions_(0),
|
||||
workspace_size_(sizeof(curandState)),
|
||||
replacement_(true) {}
|
||||
~MultinomialGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
|
@ -49,6 +54,19 @@ class MultinomialGpuKernel : public GpuKernel {
|
|||
int categories = SizeToInt(inputs[0]->size / sizeof(T)) / distributions_;
|
||||
int num_sample = SizeToInt(outputs[0]->size / sizeof(T)) / distributions_;
|
||||
// check input
|
||||
T *cum_sum_input = nullptr;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cum_sum_input), input_size_0_),
|
||||
"cudaMalloc failed.");
|
||||
CheckPeram(input_addr, cum_sum_input, categories, stream_ptr);
|
||||
if (replacement_) {
|
||||
Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_),
|
||||
IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cum_sum_input), "cudaFree failed.");
|
||||
return true;
|
||||
}
|
||||
|
||||
void CheckPeram(const T *input_addr, T *cum_sum_input, int categories, void *stream_ptr) {
|
||||
T *flag = nullptr;
|
||||
T *cflag = nullptr;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cflag), sizeof(T)), "cudaMalloc failed.");
|
||||
|
@ -67,9 +85,6 @@ class MultinomialGpuKernel : public GpuKernel {
|
|||
if (*flag > 0) {
|
||||
MS_LOG(EXCEPTION) << "Input is invalid (input element < 0)";
|
||||
}
|
||||
T *cum_sum_input = nullptr;
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast<void **>(&cum_sum_input), input_size_0_),
|
||||
"cudaMalloc failed.");
|
||||
CumSum(input_addr, cum_sum_input, cum_sum_input, IntToSize(distributions_), IntToSize(categories), 1,
|
||||
IntToSize(categories), 1, false, false, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
|
@ -82,14 +97,10 @@ class MultinomialGpuKernel : public GpuKernel {
|
|||
if (*flag > 0) {
|
||||
MS_LOG(EXCEPTION) << "Input is invalid (sum <= 0)";
|
||||
}
|
||||
Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_),
|
||||
IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cum_sum_input), "cudaFree failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cflag), "cudaFree failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(flag), "cudaFreeHost failed.");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
|
@ -114,9 +125,15 @@ class MultinomialGpuKernel : public GpuKernel {
|
|||
}
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
output_size_ = sizeof(int);
|
||||
workspace_size_ = sizeof(int);
|
||||
replacement_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("replacement"));
|
||||
if (replacement_) {
|
||||
for (size_t i = 0; i < output_shape.size(); i++) {
|
||||
output_size_ *= output_shape[i];
|
||||
workspace_size_ *= output_shape[i];
|
||||
}
|
||||
}
|
||||
if (replacement_) {
|
||||
workspace_size_ = output_size_;
|
||||
}
|
||||
seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed"));
|
||||
InitSizeLists();
|
||||
|
@ -136,6 +153,7 @@ class MultinomialGpuKernel : public GpuKernel {
|
|||
size_t output_size_;
|
||||
size_t distributions_;
|
||||
size_t workspace_size_;
|
||||
bool replacement_;
|
||||
int seed_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
|
|
|
@ -20,8 +20,6 @@ from .. import functional as F
|
|||
from ..primitive import constexpr
|
||||
from .multitype_ops import _constexpr_utils as const_utils
|
||||
from ...common import dtype as mstype
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
|
||||
# set graph-level RNG seed
|
||||
_GRAPH_SEED = 0
|
||||
|
@ -204,14 +202,13 @@ def multinomial(inputs, num_sample, replacement=True, seed=0):
|
|||
Note:
|
||||
The rows of input do not need to sum to one (in which case we use the values as weights),
|
||||
but must be non-negative, finite and have a non-zero sum.
|
||||
Args:
|
||||
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
|
||||
Default: 0.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
|
||||
- **num_samples** (int) - number of samples to draw.
|
||||
- **replacement** (bool, optional) - whether to draw with replacement or not, default True.
|
||||
Args:
|
||||
input (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
|
||||
num_samples (int) - number of samples to draw.
|
||||
replacement (bool, optional) - whether to draw with replacement or not, default True.
|
||||
seed (int, optional) - used as entropy source for Random number engines generating pseudo-random numbers.
|
||||
Must be non-negative. Default: 0.
|
||||
|
||||
Outputs:
|
||||
Tensor. have the same rows with input, each row has num_samples sampled indices.
|
||||
|
@ -222,21 +219,19 @@ def multinomial(inputs, num_sample, replacement=True, seed=0):
|
|||
"""
|
||||
shape = P.Shape()
|
||||
reshape = P.Reshape()
|
||||
validator.check_value_type('replacement', replacement, (bool,), None)
|
||||
validator.check_value_type('num_sample', num_sample, (int,), None)
|
||||
validator.check_integer("num_sample", num_sample, 0, Rel.GT, None)
|
||||
if inputs.dim() != 1 and inputs.dim() != 2:
|
||||
raise ValueError("inputs dim must be 1d or 2d")
|
||||
if not replacement:
|
||||
P.Multinomial(replacement=replacement, seed=seed)(inputs, num_sample)
|
||||
if shape(inputs)[-1] < num_sample:
|
||||
raise ValueError("num_sample must be less than shape(input)[-1] without replacement")
|
||||
n_dist = 1
|
||||
if len(shape(inputs)) > 1:
|
||||
n_dist = shape(inputs)[-2]
|
||||
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,))
|
||||
random_uniform = P.UniformReal(seed=seed)((n_dist * shape(inputs)[-1],))
|
||||
if n_dist != 1:
|
||||
random_uniform = reshape(random_uniform, (n_dist, num_sample))
|
||||
random_uniform = reshape(random_uniform, (n_dist, shape(inputs)[-1]))
|
||||
vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)
|
||||
_, indices = P.TopK()(vals, num_sample)
|
||||
return indices
|
||||
return P.Multinomial(seed=seed)(inputs, num_sample)
|
||||
return P.Multinomial(replacement=replacement, seed=seed)(inputs, num_sample)
|
||||
|
|
|
@ -438,11 +438,12 @@ class Multinomial(PrimitiveWithInfer):
|
|||
but must be non-negative, finite and have a non-zero sum.
|
||||
Args:
|
||||
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
|
||||
Default: 0.
|
||||
Must be non-negative. Default: 0.
|
||||
replacement(bool) - whether to draw with replacement or not.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2 dims.
|
||||
- **num_samples** (int) - number of samples to draw.
|
||||
- **num_samples** (int32) - number of samples to draw.
|
||||
|
||||
Outputs:
|
||||
Tensor. have the same rows with input, each row has num_samples sampled indices.
|
||||
|
@ -450,13 +451,15 @@ class Multinomial(PrimitiveWithInfer):
|
|||
Examples:
|
||||
>>> input = Tensor([0., 9., 4., 0.], mstype.float32)
|
||||
>>> multinomial = P.Multinomial(seed=10)
|
||||
>>> output = multinomial(input, 2)
|
||||
>>> output = multinomial(input, 2, True)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, seed=0):
|
||||
def __init__(self, replacement=True, seed=0):
|
||||
"""init"""
|
||||
validator.check_value_type("seed", seed, [int], self.name)
|
||||
validator.check_integer("seed", seed, 0, Rel.GE, self.name)
|
||||
validator.check_value_type("replacement", replacement, [bool], self.name)
|
||||
self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
|
||||
|
||||
def __infer__(self, inputs, num_samples):
|
||||
|
@ -467,7 +470,7 @@ class Multinomial(PrimitiveWithInfer):
|
|||
num_samples_value = num_samples["value"]
|
||||
if num_samples_value is None:
|
||||
raise ValueError(f"For {self.name}, shape nust be const")
|
||||
validator.check_value_type("num_samples", num_samples_value, [int], self.name)
|
||||
validator.check_value_type("num_samples", num_samples_value, (int,), self.name)
|
||||
validator.check_integer("num_samples", num_samples_value, 0, Rel.GT, None)
|
||||
y_shape = (num_samples_value,)
|
||||
if len(input_shape) == 2:
|
||||
|
|
Loading…
Reference in New Issue