forked from OSSInnovation/mindspore
fix Categorical
This commit is contained in:
parent
ac239b6506
commit
3a7e7802b0
|
@ -244,8 +244,8 @@ def logits_to_probs(logits, is_binary=False):
|
|||
is_binary (bool)
|
||||
"""
|
||||
if is_binary:
|
||||
return nn.sigmoid()(logits)
|
||||
return nn.softmax(axis=-1)(logits)
|
||||
return nn.Sigmoid()(logits)
|
||||
return nn.Softmax(axis=-1)(logits)
|
||||
|
||||
|
||||
def clamp_probs(probs):
|
||||
|
@ -300,6 +300,9 @@ def raise_none_error(name):
|
|||
raise TypeError(f"the type {name} should be subclass of Tensor."
|
||||
f" It should not be None since it is not specified during initialization.")
|
||||
|
||||
@constexpr
|
||||
def raise_probs_logits_error():
|
||||
raise TypeError("Either 'probs' or 'logits' must be specified, but not both.")
|
||||
|
||||
@constexpr
|
||||
def raise_not_impl_error(name):
|
||||
|
|
|
@ -17,7 +17,7 @@ import numpy as np
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import logits_to_probs, probs_to_logits, check_tensor_type, cast_to_tensor
|
||||
from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error
|
||||
|
||||
|
||||
class Categorical(Distribution):
|
||||
|
@ -71,9 +71,11 @@ class Categorical(Distribution):
|
|||
dtype=mstype.int32,
|
||||
name="Categorical"):
|
||||
param = dict(locals())
|
||||
valid_dtype = mstype.int_type
|
||||
check_type(dtype, valid_dtype, "Categorical")
|
||||
super(Categorical, self).__init__(seed, dtype, name, param)
|
||||
if (probs is None) == (logits is None):
|
||||
raise ValueError("Either 'prob' or 'logits' must be specified, but not both.")
|
||||
raise_probs_logits_error()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=True)
|
||||
self.log = P.Log()
|
||||
self.exp = P.Exp()
|
||||
|
@ -127,8 +129,7 @@ class Categorical(Distribution):
|
|||
Returns:
|
||||
Tensor, shape is sample_shape + shape(probs)[:-1]
|
||||
"""
|
||||
if not isinstance(sample_shape, tuple):
|
||||
raise ValueError("sample shape must be a tuple")
|
||||
self.checktuple(sample_shape, 'shape')
|
||||
num_sample = 1
|
||||
for i in sample_shape:
|
||||
num_sample *= i
|
||||
|
@ -136,7 +137,7 @@ class Categorical(Distribution):
|
|||
samples = self.mutinomial(probs_2d, num_sample)
|
||||
extend_shape = sample_shape
|
||||
if len(self.shape(self._probs)) > 1:
|
||||
extend_shape = self.shape(self._probs)[:-1] + sample_shape
|
||||
extend_shape = sample_shape + self.shape(self._probs)[:-1]
|
||||
return self.cast(self.reshape(samples, extend_shape), self.dtype)
|
||||
|
||||
def _broad_cast_shape(self, a, b):
|
||||
|
@ -183,15 +184,16 @@ class Categorical(Distribution):
|
|||
if value is not None:
|
||||
check_tensor_type("value", value, [mstype.float32, bool, mstype.int32])
|
||||
value = self.expandim(self.cast(value, mstype.float32), -1)
|
||||
broad_shape = self._broad_cast_shape(value, self._logits)
|
||||
index = cast_to_tensor(np.arange(self.shape(value)[0]).astype(np.float32))
|
||||
index = self.expandim(index, -1)
|
||||
logits = self._logits if self._logits.dim() == 1 else self.expandim(self._logits, 0)
|
||||
broad_shape = self._broad_cast_shape(value, logits)
|
||||
broad = P.BroadcastTo(broad_shape)
|
||||
value = broad(value)[..., :1]
|
||||
index = cast_to_tensor(np.arange(broad_shape[-1]).astype(np.float32))
|
||||
index = self.expandim(index, -1)
|
||||
index = broad(index)[..., :1]
|
||||
value = self.concat((index, value))
|
||||
value = self.cast(value, mstype.int32)
|
||||
return self.gather(self._logits, value)
|
||||
return self.gather(logits, value)
|
||||
return None
|
||||
|
||||
def _entropy(self):
|
||||
|
@ -209,7 +211,7 @@ class Categorical(Distribution):
|
|||
Enumerate categories.
|
||||
"""
|
||||
num_events = self._num_events
|
||||
values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.int32)
|
||||
values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.float32)
|
||||
values = self.reshape(values, (num_events, 1))
|
||||
if expand:
|
||||
values = P.BroadcastTo((num_events, self._batch_shape))(values)
|
||||
|
|
|
@ -204,14 +204,15 @@ def multinomial(inputs, num_sample, replacement=True, seed=0):
|
|||
but must be non-negative, finite and have a non-zero sum.
|
||||
|
||||
Args:
|
||||
input (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
|
||||
num_samples (int) - number of samples to draw.
|
||||
replacement (bool, optional) - whether to draw with replacement or not, default True.
|
||||
seed (int, optional) - used as entropy source for Random number engines generating pseudo-random numbers.
|
||||
inputs (Tensor): the input tensor containing probabilities; must have 1 or 2 dimensions and float32 data type.
|
||||
num_sample (int): number of samples to draw.
|
||||
replacement (bool, optional): whether to draw with replacement or not, default True.
|
||||
seed (int, optional): used as the entropy source for the random number engine that generates pseudo-random numbers.
|
||||
Must be non-negative. Default: 0.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same number of rows as the input; each row contains num_sample sampled indices.
|
||||
The dtype is float32.
|
||||
|
||||
Examples:
|
||||
>>> input = Tensor([0, 9, 4, 0], mstype.float32)
|
||||
|
|
Loading…
Reference in New Issue