forked from mindspore-Ecosystem/mindspore
!7735 fix gpu multinomial seed issue and GRAPH_MODE
Merge pull request !7735 from baihuawei/fixmultinomial
This commit is contained in:
commit
5b28016b4d
|
@ -102,13 +102,15 @@ __global__ void MultinomialKernel(int seed, T *input, int num_sample, curandStat
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions,
|
||||
size_t categories, cudaStream_t cuda_stream) {
|
||||
void Multinomial(int seed, int seed2, T *input, int num_sample, curandState *globalState, int *output,
|
||||
size_t distributions, size_t categories, cudaStream_t cuda_stream) {
|
||||
int RNG_seed = 0;
|
||||
if (seed != 0) {
|
||||
std::random_device rd;
|
||||
if (seed2 != 0) {
|
||||
RNG_seed = seed2;
|
||||
} else if (seed != 0) {
|
||||
RNG_seed = seed;
|
||||
} else {
|
||||
std::random_device rd;
|
||||
RNG_seed = static_cast<int>(rd());
|
||||
}
|
||||
int count = distributions * num_sample;
|
||||
|
@ -117,8 +119,8 @@ void Multinomial(int seed, T *input, int num_sample, curandState *globalState, i
|
|||
return;
|
||||
}
|
||||
|
||||
template void Multinomial<float>(int seed, float *input, int num_sample, curandState *globalState, int *output,
|
||||
size_t distributions, size_t categories, cudaStream_t cuda_stream);
|
||||
template void Multinomial<float>(int seed, int seed2, float *input, int num_sample, curandState *globalState,
|
||||
int *output, size_t distributions, size_t categories, cudaStream_t cuda_stream);
|
||||
template void CheckNonNeg<float>(const size_t size, const float *input, float *output, cudaStream_t cuda_stream);
|
||||
template void CheckZero<float>(const size_t distributions, const size_t categories, const float *input, float *output,
|
||||
cudaStream_t cuda_stream);
|
||||
|
|
|
@ -20,8 +20,8 @@
|
|||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
void Multinomial(int seed, T *input, int num_sample, curandState *globalState, int *output, size_t distributions,
|
||||
size_t categories, cudaStream_t cuda_stream);
|
||||
void Multinomial(int seed, int seed2, T *input, int num_sample, curandState *globalState, int *output,
|
||||
size_t distributions, size_t categories, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream);
|
||||
template <typename T>
|
||||
|
|
|
@ -32,7 +32,13 @@ namespace kernel {
|
|||
template <typename T>
|
||||
class MultinomialGpuKernel : public GpuKernel {
|
||||
public:
|
||||
MultinomialGpuKernel() : input_size_0_(0), output_size_(0), distributions_(0), workspace_size_(sizeof(curandState)) {}
|
||||
MultinomialGpuKernel()
|
||||
: input_size_0_(0),
|
||||
output_size_(0),
|
||||
distributions_(0),
|
||||
workspace_size_(sizeof(curandState)),
|
||||
seed_(0),
|
||||
seed2_(0) {}
|
||||
~MultinomialGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
|
@ -52,7 +58,7 @@ class MultinomialGpuKernel : public GpuKernel {
|
|||
IntToSize(categories), 1, false, false, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
NormInput(cum_sum_input, IntToSize(distributions_), IntToSize(categories),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_),
|
||||
Multinomial(seed_, seed2_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_),
|
||||
IntToSize(categories), reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
@ -87,6 +93,7 @@ class MultinomialGpuKernel : public GpuKernel {
|
|||
}
|
||||
workspace_size_ = output_size_;
|
||||
seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed"));
|
||||
seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"));
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
@ -106,6 +113,7 @@ class MultinomialGpuKernel : public GpuKernel {
|
|||
size_t distributions_;
|
||||
size_t workspace_size_;
|
||||
int seed_;
|
||||
int seed2_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
|
|
@ -33,7 +33,7 @@ class Categorical(Distribution):
|
|||
Args:
|
||||
probs (Tensor, list, numpy.ndarray, Parameter): Event probabilities.
|
||||
seed (int): The global seed is used in sampling. Global seed is used if it is None. Default: None.
|
||||
dtype (mindspore.dtype): The type of the distribution. Default: mstype.int32.
|
||||
dtype (mindspore.dtype): The type of the event samples. Default: mstype.int32.
|
||||
name (str): The name of the distribution. Default: Categorical.
|
||||
|
||||
Note:
|
||||
|
|
|
@ -112,6 +112,11 @@ def is_same_type(inst, type_):
|
|||
"""
|
||||
return inst == type_
|
||||
|
||||
@constexpr
|
||||
def check_valid_dim(dim, name):
|
||||
if dim not in (1, 2):
|
||||
raise ValueError(
|
||||
f"For {name}, inputs dim must be 1d or 2d")
|
||||
|
||||
@constexpr
|
||||
def check_valid_type(data_type, value_type, name):
|
||||
|
|
|
@ -205,7 +205,7 @@ def poisson(shape, mean, seed=None):
|
|||
value = random_poisson(shape, mean)
|
||||
return value
|
||||
|
||||
def multinomial(inputs, num_sample, replacement=True, seed=0):
|
||||
def multinomial(inputs, num_sample, replacement=True, seed=None):
|
||||
r"""
|
||||
Returns a tensor sampled from the multinomial probability distribution located in the corresponding
|
||||
row of the input tensor.
|
||||
|
@ -232,18 +232,18 @@ def multinomial(inputs, num_sample, replacement=True, seed=0):
|
|||
"""
|
||||
shape = P.Shape()
|
||||
reshape = P.Reshape()
|
||||
if inputs.dim() != 1 and inputs.dim() != 2:
|
||||
const_utils.raise_value_error("inputs dim must be 1d or 2d")
|
||||
const_utils.check_valid_dim(len(shape(inputs)), "multinomial")
|
||||
seed1, seed2 = _get_seed(seed, "multinomial")
|
||||
if not replacement:
|
||||
if shape(inputs)[-1] < num_sample:
|
||||
const_utils.raise_value_error("num_sample must be less than shape(input)[-1] without replacement")
|
||||
n_dist = 1
|
||||
if len(shape(inputs)) > 1:
|
||||
n_dist = shape(inputs)[-2]
|
||||
random_uniform = P.UniformReal(seed=seed)((n_dist * shape(inputs)[-1],))
|
||||
random_uniform = P.UniformReal(seed1, seed2)((n_dist * shape(inputs)[-1],))
|
||||
if n_dist != 1:
|
||||
random_uniform = reshape(random_uniform, (n_dist, shape(inputs)[-1]))
|
||||
vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6)
|
||||
_, indices = P.TopK()(vals, num_sample)
|
||||
return indices
|
||||
return P.Multinomial(seed=seed)(inputs, num_sample)
|
||||
return P.Multinomial(seed1, seed2)(inputs, num_sample)
|
||||
|
|
|
@ -433,8 +433,8 @@ class Multinomial(PrimitiveWithInfer):
|
|||
The rows of input do not need to sum to one (in which case we use the values as weights),
|
||||
but must be non-negative, finite and have a non-zero sum.
|
||||
Args:
|
||||
seed (int): Seed data is used as entropy source for Random number engines to generate pseudo-random numbers.
|
||||
Must be non-negative. Default: 0.
|
||||
seed (int): Random seed, must be non-negative. Default: 0.
|
||||
seed2 (int): Random seed2, must be non-negative. Default: 0.
|
||||
Inputs:
|
||||
- **input** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2
|
||||
dimensions.
|
||||
|
@ -450,10 +450,10 @@ class Multinomial(PrimitiveWithInfer):
|
|||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, seed=0):
|
||||
def __init__(self, seed=0, seed2=0):
|
||||
"""init"""
|
||||
Validator.check_value_type("seed", seed, [int], self.name)
|
||||
Validator.check_non_negative_int(seed, "seed", self.name)
|
||||
Validator.check_non_negative_int(seed2, "seed2", self.name)
|
||||
self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
|
||||
|
||||
def __infer__(self, inputs, num_samples):
|
||||
|
|
|
@ -17,9 +17,20 @@ import numpy as np
|
|||
import pytest
|
||||
from mindspore.ops import composite as C
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
context.set_context(device_target='GPU')
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, sample, replacement, seed=0):
|
||||
super(Net, self).__init__()
|
||||
self.sample = sample
|
||||
self.replacement = replacement
|
||||
self.seed = seed
|
||||
|
||||
def construct(self, x):
|
||||
return C.multinomial(x, self.sample, self.replacement, self.seed)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -27,9 +38,12 @@ context.set_context(device_target='GPU')
|
|||
def test_multinomial():
|
||||
x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32))
|
||||
x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32))
|
||||
out0 = C.multinomial(x0, 1, True)
|
||||
out1 = C.multinomial(x0, 2, True)
|
||||
out2 = C.multinomial(x1, 6, True)
|
||||
net0 = Net(1, True, 20)
|
||||
net1 = Net(2, True, 20)
|
||||
net2 = Net(6, True, 20)
|
||||
out0 = net0(x0)
|
||||
out1 = net1(x0)
|
||||
out2 = net2(x1)
|
||||
assert out0.asnumpy().shape == (1,)
|
||||
assert out1.asnumpy().shape == (2,)
|
||||
assert out2.asnumpy().shape == (2, 6)
|
||||
|
|
Loading…
Reference in New Issue