fix Categorical

baihuawei 2020-08-27 11:17:42 +08:00
parent ac239b6506
commit 3a7e7802b0
3 changed files with 22 additions and 16 deletions

View File

@@ -244,8 +244,8 @@ def logits_to_probs(logits, is_binary=False):
         is_binary (bool)
     """
     if is_binary:
-        return nn.sigmoid()(logits)
-    return nn.softmax(axis=-1)(logits)
+        return nn.Sigmoid()(logits)
+    return nn.Softmax(axis=-1)(logits)


 def clamp_probs(probs):
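The lowercase nn.sigmoid / nn.softmax names do not exist in mindspore.nn; the fix switches to the Sigmoid and Softmax cells, which are instantiated first and then applied to the tensor. A minimal sketch (not part of the commit) of how the corrected helper is expected to behave:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    logits = Tensor(np.array([1.0, 2.0, 3.0]), mstype.float32)
    probs = nn.Softmax(axis=-1)(logits)                              # multi-class: entries sum to 1
    binary = nn.Sigmoid()(Tensor(np.array([0.5]), mstype.float32))   # is_binary=True path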
@@ -300,6 +300,9 @@ def raise_none_error(name):
     raise TypeError(f"the type {name} should be subclass of Tensor."
                     f" It should not be None since it is not specified during initialization.")

+@constexpr
+def raise_probs_logits_error():
+    raise TypeError("Either 'probs' or 'logits' must be specified, but not both.")

 @constexpr
 def raise_not_impl_error(name):
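Routing the error through a @constexpr helper matches the other error helpers in this utilities module and lets the same check be reached from graph-compiled code, where the helper is evaluated at compile time. A hedged sketch of how such a helper is called; the _check_exclusive wrapper below is hypothetical, not code from the repository:

    from mindspore.ops import constexpr

    @constexpr
    def raise_probs_logits_error():
        raise TypeError("Either 'probs' or 'logits' must be specified, but not both.")

    def _check_exclusive(probs, logits):
        # hypothetical caller: exactly one of the two parameters may be given
        if (probs is None) == (logits is None):
            raise_probs_logits_error()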

View File

@@ -17,7 +17,7 @@ import numpy as np
 from mindspore.ops import operations as P
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import logits_to_probs, probs_to_logits, check_tensor_type, cast_to_tensor
+from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error


 class Categorical(Distribution):
@@ -71,9 +71,11 @@ class Categorical(Distribution):
                  dtype=mstype.int32,
                  name="Categorical"):
         param = dict(locals())
+        valid_dtype = mstype.int_type
+        check_type(dtype, valid_dtype, "Categorical")
         super(Categorical, self).__init__(seed, dtype, name, param)
         if (probs is None) == (logits is None):
-            raise ValueError("Either 'prob' or 'logits' must be specified, but not both.")
+            raise_probs_logits_error()
         self.reduce_sum = P.ReduceSum(keep_dims=True)
         self.log = P.Log()
         self.exp = P.Exp()
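With the new check_type call the constructor rejects non-integer dtypes up front, and the probs/logits exclusivity error now goes through the shared helper. A hedged usage sketch (the tensors and failure comments are illustrative, not taken from the commit):

    import mindspore.nn.probability.distribution as msd
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    # exactly one of probs / logits, plus an integer dtype, is accepted
    ca = msd.Categorical(probs=Tensor([0.2, 0.8], mstype.float32), dtype=mstype.int32)

    # these would now fail during construction:
    #   msd.Categorical(dtype=mstype.int32)                      # neither probs nor logits
    #   msd.Categorical(probs=p, logits=l, dtype=mstype.int32)   # both given
    #   msd.Categorical(probs=p, dtype=mstype.float32)           # non-integer dtype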
@@ -127,8 +129,7 @@ class Categorical(Distribution):
         Returns:
             Tensor, shape is shape(probs)[:-1] + sample_shape
         """
-        if not isinstance(sample_shape, tuple):
-            raise ValueError("sample shape must be a tuple")
+        self.checktuple(sample_shape, 'shape')
         num_sample = 1
         for i in sample_shape:
             num_sample *= i
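The hand-written isinstance check is replaced by the distribution's own checktuple validator; the loop that follows still just computes the total number of draws as the product of the requested dimensions. A small plain-Python illustration:

    sample_shape = (2, 3)
    num_sample = 1
    for i in sample_shape:
        num_sample *= i
    # num_sample == 6; the flat draws are reshaped back to the sample shape afterwards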
@@ -136,7 +137,7 @@ class Categorical(Distribution):
         samples = self.mutinomial(probs_2d, num_sample)
         extend_shape = sample_shape
         if len(self.shape(self._probs)) > 1:
-            extend_shape = self.shape(self._probs)[:-1] + sample_shape
+            extend_shape = sample_shape + self.shape(self._probs)[:-1]
         return self.cast(self.reshape(samples, extend_shape), self.dtype)

     def _broad_cast_shape(self, a, b):
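For batched probabilities the target shape of the final reshape is now built with the sample dimensions first and the batch dimensions of probs after them. A small illustration of what the tuple evaluates to before and after the change (shapes are assumed for the example):

    probs_shape = (4, 5)        # a batch of 4 distributions over 5 categories
    sample_shape = (2, 3)
    old_extend = probs_shape[:-1] + sample_shape    # (4, 2, 3)
    new_extend = sample_shape + probs_shape[:-1]    # (2, 3, 4)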
@@ -183,15 +184,16 @@ class Categorical(Distribution):
         if value is not None:
             check_tensor_type("value", value, [mstype.float32, bool, mstype.int32])
             value = self.expandim(self.cast(value, mstype.float32), -1)
-            broad_shape = self._broad_cast_shape(value, self._logits)
+            index = cast_to_tensor(np.arange(self.shape(value)[0]).astype(np.float32))
+            index = self.expandim(index, -1)
+            logits = self._logits if self._logits.dim() == 1 else self.expandim(self._logits, 0)
+            broad_shape = self._broad_cast_shape(value, logits)
             broad = P.BroadcastTo(broad_shape)
             value = broad(value)[..., :1]
-            index = cast_to_tensor(np.arange(broad_shape[-1]).astype(np.float32))
-            index = self.expandim(index, -1)
             index = broad(index)[..., :1]
             value = self.concat((index, value))
             value = self.cast(value, mstype.int32)
-            return self.gather(self._logits, value)
+            return self.gather(logits, value)
         return None

     def _entropy(self):
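The reworked lookup builds one row index per value, pairs it with the category index, and gathers from the (possibly expanded) logits instead of indexing self._logits directly. A hedged numpy sketch of the lookup, assuming self.gather has GatherNd-style semantics:

    import numpy as np

    logits = np.log(np.array([[0.2, 0.8],
                              [0.5, 0.5],
                              [0.9, 0.1]]))
    value = np.array([1, 0, 1])                     # category chosen for each row
    index = np.arange(value.shape[0])               # row index, as in the new code
    coords = np.stack([index, value], axis=-1)      # (row, category) pairs, shape (3, 2)
    log_prob = logits[coords[:, 0], coords[:, 1]]   # what a GatherNd-style lookup returns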
@@ -209,7 +211,7 @@ class Categorical(Distribution):
         Enumerate categories.
         """
         num_events = self._num_events
-        values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.int32)
+        values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.float32)
         values = self.reshape(values, (num_events, 1))
         if expand:
             values = P.BroadcastTo((num_events, self._batch_shape))(values)
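Enumeration still lists every category index once per batch element; the only change is that the values tensor is now created as float32 instead of int32. A hedged numpy sketch with an assumed batch size:

    import numpy as np

    num_events, batch = 3, 4
    values = np.arange(num_events).astype(np.int32).astype(np.float32)  # now float32
    values = values.reshape(num_events, 1)
    values = np.broadcast_to(values, (num_events, batch))               # expand=True path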

View File

@@ -204,14 +204,15 @@ def multinomial(inputs, num_sample, replacement=True, seed=0):
     but must be non-negative, finite and have a non-zero sum.

     Args:
-        input (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
-        num_samples (int) - number of samples to draw.
-        replacement (bool, optional) - whether to draw with replacement or not, default True.
-        seed (int, optional) - used as entropy source for Random number engines generating pseudo-random numbers.
+        inputs (Tensor): the input tensor containing probabilities, must be 1 or 2 dims. With float32 data type.
+        num_sample (int): number of samples to draw.
+        replacement (bool, optional): whether to draw with replacement or not, default True.
+        seed (int, optional): used as entropy source for Random number engines generating pseudo-random numbers.
             Must be non-negative. Default: 0.

     Outputs:
         Tensor. have the same rows with input, each row has num_samples sampled indices.
+        The dtype is float32.

     Examples:
         >>> input = Tensor([0, 9, 4, 0], mstype.float32)
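A hedged usage sketch matching the corrected docstring; the sampled indices are random, so the comment on the result is only illustrative, and the composite import path is assumed from the documented signature:

    from mindspore import Tensor
    from mindspore.common import dtype as mstype
    from mindspore.ops import composite as C

    inputs = Tensor([0, 9, 4, 0], mstype.float32)   # unnormalised per-category weights
    samples = C.multinomial(inputs, num_sample=2, replacement=True, seed=0)
    # two sampled indices; categories 0 and 3 have zero weight, so they are never drawn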