fix Categorical

This commit is contained in:
baihuawei 2020-08-27 11:17:42 +08:00
parent ac239b6506
commit 3a7e7802b0
3 changed files with 22 additions and 16 deletions

View File

@ -244,8 +244,8 @@ def logits_to_probs(logits, is_binary=False):
is_binary (bool)
"""
if is_binary:
return nn.sigmoid()(logits)
return nn.softmax(axis=-1)(logits)
return nn.Sigmoid()(logits)
return nn.Softmax(axis=-1)(logits)
def clamp_probs(probs):
@ -300,6 +300,9 @@ def raise_none_error(name):
raise TypeError(f"the type {name} should be subclass of Tensor."
f" It should not be None since it is not specified during initialization.")
@constexpr
def raise_probs_logits_error():
raise TypeError("Either 'probs' or 'logits' must be specified, but not both.")
@constexpr
def raise_not_impl_error(name):

View File

@ -17,7 +17,7 @@ import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import logits_to_probs, probs_to_logits, check_tensor_type, cast_to_tensor
from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error
class Categorical(Distribution):
@ -71,9 +71,11 @@ class Categorical(Distribution):
dtype=mstype.int32,
name="Categorical"):
param = dict(locals())
valid_dtype = mstype.int_type
check_type(dtype, valid_dtype, "Categorical")
super(Categorical, self).__init__(seed, dtype, name, param)
if (probs is None) == (logits is None):
raise ValueError("Either 'prob' or 'logits' must be specified, but not both.")
raise_probs_logits_error()
self.reduce_sum = P.ReduceSum(keep_dims=True)
self.log = P.Log()
self.exp = P.Exp()
@ -127,8 +129,7 @@ class Categorical(Distribution):
Returns:
Tensor, shape is shape(probs)[:-1] + sample_shape
"""
if not isinstance(sample_shape, tuple):
raise ValueError("sample shape must be a tuple")
self.checktuple(sample_shape, 'shape')
num_sample = 1
for i in sample_shape:
num_sample *= i
@ -136,7 +137,7 @@ class Categorical(Distribution):
samples = self.mutinomial(probs_2d, num_sample)
extend_shape = sample_shape
if len(self.shape(self._probs)) > 1:
extend_shape = self.shape(self._probs)[:-1] + sample_shape
extend_shape = sample_shape + self.shape(self._probs)[:-1]
return self.cast(self.reshape(samples, extend_shape), self.dtype)
def _broad_cast_shape(self, a, b):
@ -183,15 +184,16 @@ class Categorical(Distribution):
if value is not None:
check_tensor_type("value", value, [mstype.float32, bool, mstype.int32])
value = self.expandim(self.cast(value, mstype.float32), -1)
broad_shape = self._broad_cast_shape(value, self._logits)
index = cast_to_tensor(np.arange(self.shape(value)[0]).astype(np.float32))
index = self.expandim(index, -1)
logits = self._logits if self._logits.dim() == 1 else self.expandim(self._logits, 0)
broad_shape = self._broad_cast_shape(value, logits)
broad = P.BroadcastTo(broad_shape)
value = broad(value)[..., :1]
index = cast_to_tensor(np.arange(broad_shape[-1]).astype(np.float32))
index = self.expandim(index, -1)
index = broad(index)[..., :1]
value = self.concat((index, value))
value = self.cast(value, mstype.int32)
return self.gather(self._logits, value)
return self.gather(logits, value)
return None
def _entropy(self):
@ -209,7 +211,7 @@ class Categorical(Distribution):
Enumerate categories.
"""
num_events = self._num_events
values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.int32)
values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.float32)
values = self.reshape(values, (num_events, 1))
if expand:
values = P.BroadcastTo((num_events, self._batch_shape))(values)

View File

@ -204,14 +204,15 @@ def multinomial(inputs, num_sample, replacement=True, seed=0):
but must be non-negative, finite and have a non-zero sum.
Args:
input (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
num_samples (int) - number of samples to draw.
replacement (bool, optional) - whether to draw with replacement or not, default True.
seed (int, optional) - used as entropy source for Random number engines generating pseudo-random numbers.
inputs (Tensor): the input tensor containing probabilities, must be 1 or 2 dims. With float32 data type.
num_sample (int): number of samples to draw.
replacement (bool, optional): whether to draw with replacement or not, default True.
seed (int, optional): used as entropy source for Random number engines generating pseudo-random numbers.
Must be non-negative. Default: 0.
Outputs:
Tensor. have the same rows with input, each row has num_samples sampled indices.
The dtype is float32.
Examples:
>>> input = Tensor([0, 9, 4, 0], mstype.float32)