diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py
index d49088ad48..5e81d79ae6 100644
--- a/mindspore/nn/probability/distribution/_utils/utils.py
+++ b/mindspore/nn/probability/distribution/_utils/utils.py
@@ -244,8 +244,8 @@ def logits_to_probs(logits, is_binary=False):
         is_binary (bool)
     """
     if is_binary:
-        return nn.sigmoid()(logits)
-    return nn.softmax(axis=-1)(logits)
+        return nn.Sigmoid()(logits)
+    return nn.Softmax(axis=-1)(logits)
 
 
 def clamp_probs(probs):
@@ -300,6 +300,9 @@ def raise_none_error(name):
     raise TypeError(f"the type {name} should be subclass of Tensor."
                     f" It should not be None since it is not specified during initialization.")
 
+@constexpr
+def raise_probs_logits_error():
+    raise TypeError("Either 'probs' or 'logits' must be specified, but not both.")
 
 @constexpr
 def raise_not_impl_error(name):
diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py
index 81b4152e9f..9219841ff7 100644
--- a/mindspore/nn/probability/distribution/categorical.py
+++ b/mindspore/nn/probability/distribution/categorical.py
@@ -17,7 +17,7 @@ import numpy as np
 from mindspore.ops import operations as P
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import logits_to_probs, probs_to_logits, check_tensor_type, cast_to_tensor
+from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error
 
 
 class Categorical(Distribution):
@@ -71,9 +71,11 @@ class Categorical(Distribution):
                  dtype=mstype.int32,
                  name="Categorical"):
         param = dict(locals())
+        valid_dtype = mstype.int_type
+        check_type(dtype, valid_dtype, "Categorical")
         super(Categorical, self).__init__(seed, dtype, name, param)
         if (probs is None) == (logits is None):
-            raise ValueError("Either 'prob' or 'logits' must be specified, but not both.")
+            raise_probs_logits_error()
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.log = P.Log()
         self.exp = P.Exp()
@@ -127,8 +129,7 @@ class Categorical(Distribution):
         Returns:
             Tensor, shape is shape(probs)[:-1] + sample_shape
         """
-        if not isinstance(sample_shape, tuple):
-            raise ValueError("sample shape must be a tuple")
+        self.checktuple(sample_shape, 'shape')
         num_sample = 1
         for i in sample_shape:
             num_sample *= i
@@ -136,7 +137,7 @@ class Categorical(Distribution):
         samples = self.mutinomial(probs_2d, num_sample)
         extend_shape = sample_shape
         if len(self.shape(self._probs)) > 1:
-            extend_shape = self.shape(self._probs)[:-1] + sample_shape
+            extend_shape = sample_shape + self.shape(self._probs)[:-1]
         return self.cast(self.reshape(samples, extend_shape), self.dtype)
 
     def _broad_cast_shape(self, a, b):
@@ -183,15 +184,16 @@ class Categorical(Distribution):
         if value is not None:
             check_tensor_type("value", value, [mstype.float32, bool, mstype.int32])
             value = self.expandim(self.cast(value, mstype.float32), -1)
-            broad_shape = self._broad_cast_shape(value, self._logits)
+            index = cast_to_tensor(np.arange(self.shape(value)[0]).astype(np.float32))
+            index = self.expandim(index, -1)
+            logits = self._logits if self._logits.dim() == 1 else self.expandim(self._logits, 0)
+            broad_shape = self._broad_cast_shape(value, logits)
             broad = P.BroadcastTo(broad_shape)
             value = broad(value)[..., :1]
-            index = cast_to_tensor(np.arange(broad_shape[-1]).astype(np.float32))
-            index = self.expandim(index, -1)
             index = broad(index)[..., :1]
             value = self.concat((index, value))
             value = self.cast(value, mstype.int32)
-            return self.gather(self._logits, value)
+            return self.gather(logits, value)
         return None
 
     def _entropy(self):
@@ -209,7 +211,7 @@ class Categorical(Distribution):
         Enumerate categories.
         """
         num_events = self._num_events
-        values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.int32)
+        values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.float32)
         values = self.reshape(values, (num_events, 1))
         if expand:
             values = P.BroadcastTo((num_events, self._batch_shape))(values)
diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py
index c1d0bf52c9..7ac45d2ed7 100644
--- a/mindspore/ops/composite/random_ops.py
+++ b/mindspore/ops/composite/random_ops.py
@@ -204,14 +204,15 @@ def multinomial(inputs, num_sample, replacement=True, seed=0):
         but must be non-negative, finite and have a non-zero sum.
 
     Args:
-        input (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
-        num_samples (int) - number of samples to draw.
-        replacement (bool, optional) - whether to draw with replacement or not, default True.
-        seed (int, optional) - used as entropy source for Random number engines generating pseudo-random numbers.
+        inputs (Tensor): the input tensor containing probabilities, must be 1 or 2 dims. With float32 data type.
+        num_sample (int): number of samples to draw.
+        replacement (bool, optional): whether to draw with replacement or not, default True.
+        seed (int, optional): used as entropy source for Random number engines generating pseudo-random numbers.
           Must be non-negative. Default: 0.
 
     Outputs:
         Tensor. have the same rows with input, each row has num_samples sampled indices.
+        The dtype is float32.
 
     Examples:
         >>> input = Tensor([0, 9, 4, 0], mstype.float32)