forked from mindspore-Ecosystem/mindspore
fix Categorical
This commit is contained in:
parent
ac239b6506
commit
3a7e7802b0
|
@ -244,8 +244,8 @@ def logits_to_probs(logits, is_binary=False):
|
||||||
is_binary (bool)
|
is_binary (bool)
|
||||||
"""
|
"""
|
||||||
if is_binary:
|
if is_binary:
|
||||||
return nn.sigmoid()(logits)
|
return nn.Sigmoid()(logits)
|
||||||
return nn.softmax(axis=-1)(logits)
|
return nn.Softmax(axis=-1)(logits)
|
||||||
|
|
||||||
|
|
||||||
def clamp_probs(probs):
|
def clamp_probs(probs):
|
||||||
|
@ -300,6 +300,9 @@ def raise_none_error(name):
|
||||||
raise TypeError(f"the type {name} should be subclass of Tensor."
|
raise TypeError(f"the type {name} should be subclass of Tensor."
|
||||||
f" It should not be None since it is not specified during initialization.")
|
f" It should not be None since it is not specified during initialization.")
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def raise_probs_logits_error():
|
||||||
|
raise TypeError("Either 'probs' or 'logits' must be specified, but not both.")
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def raise_not_impl_error(name):
|
def raise_not_impl_error(name):
|
||||||
|
|
|
@ -17,7 +17,7 @@ import numpy as np
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from .distribution import Distribution
|
from .distribution import Distribution
|
||||||
from ._utils.utils import logits_to_probs, probs_to_logits, check_tensor_type, cast_to_tensor
|
from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error
|
||||||
|
|
||||||
|
|
||||||
class Categorical(Distribution):
|
class Categorical(Distribution):
|
||||||
|
@ -71,9 +71,11 @@ class Categorical(Distribution):
|
||||||
dtype=mstype.int32,
|
dtype=mstype.int32,
|
||||||
name="Categorical"):
|
name="Categorical"):
|
||||||
param = dict(locals())
|
param = dict(locals())
|
||||||
|
valid_dtype = mstype.int_type
|
||||||
|
check_type(dtype, valid_dtype, "Categorical")
|
||||||
super(Categorical, self).__init__(seed, dtype, name, param)
|
super(Categorical, self).__init__(seed, dtype, name, param)
|
||||||
if (probs is None) == (logits is None):
|
if (probs is None) == (logits is None):
|
||||||
raise ValueError("Either 'prob' or 'logits' must be specified, but not both.")
|
raise_probs_logits_error()
|
||||||
self.reduce_sum = P.ReduceSum(keep_dims=True)
|
self.reduce_sum = P.ReduceSum(keep_dims=True)
|
||||||
self.log = P.Log()
|
self.log = P.Log()
|
||||||
self.exp = P.Exp()
|
self.exp = P.Exp()
|
||||||
|
@ -127,8 +129,7 @@ class Categorical(Distribution):
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, shape is shape(probs)[:-1] + sample_shape
|
Tensor, shape is shape(probs)[:-1] + sample_shape
|
||||||
"""
|
"""
|
||||||
if not isinstance(sample_shape, tuple):
|
self.checktuple(sample_shape, 'shape')
|
||||||
raise ValueError("sample shape must be a tuple")
|
|
||||||
num_sample = 1
|
num_sample = 1
|
||||||
for i in sample_shape:
|
for i in sample_shape:
|
||||||
num_sample *= i
|
num_sample *= i
|
||||||
|
@ -136,7 +137,7 @@ class Categorical(Distribution):
|
||||||
samples = self.mutinomial(probs_2d, num_sample)
|
samples = self.mutinomial(probs_2d, num_sample)
|
||||||
extend_shape = sample_shape
|
extend_shape = sample_shape
|
||||||
if len(self.shape(self._probs)) > 1:
|
if len(self.shape(self._probs)) > 1:
|
||||||
extend_shape = self.shape(self._probs)[:-1] + sample_shape
|
extend_shape = sample_shape + self.shape(self._probs)[:-1]
|
||||||
return self.cast(self.reshape(samples, extend_shape), self.dtype)
|
return self.cast(self.reshape(samples, extend_shape), self.dtype)
|
||||||
|
|
||||||
def _broad_cast_shape(self, a, b):
|
def _broad_cast_shape(self, a, b):
|
||||||
|
@ -183,15 +184,16 @@ class Categorical(Distribution):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
check_tensor_type("value", value, [mstype.float32, bool, mstype.int32])
|
check_tensor_type("value", value, [mstype.float32, bool, mstype.int32])
|
||||||
value = self.expandim(self.cast(value, mstype.float32), -1)
|
value = self.expandim(self.cast(value, mstype.float32), -1)
|
||||||
broad_shape = self._broad_cast_shape(value, self._logits)
|
index = cast_to_tensor(np.arange(self.shape(value)[0]).astype(np.float32))
|
||||||
|
index = self.expandim(index, -1)
|
||||||
|
logits = self._logits if self._logits.dim() == 1 else self.expandim(self._logits, 0)
|
||||||
|
broad_shape = self._broad_cast_shape(value, logits)
|
||||||
broad = P.BroadcastTo(broad_shape)
|
broad = P.BroadcastTo(broad_shape)
|
||||||
value = broad(value)[..., :1]
|
value = broad(value)[..., :1]
|
||||||
index = cast_to_tensor(np.arange(broad_shape[-1]).astype(np.float32))
|
|
||||||
index = self.expandim(index, -1)
|
|
||||||
index = broad(index)[..., :1]
|
index = broad(index)[..., :1]
|
||||||
value = self.concat((index, value))
|
value = self.concat((index, value))
|
||||||
value = self.cast(value, mstype.int32)
|
value = self.cast(value, mstype.int32)
|
||||||
return self.gather(self._logits, value)
|
return self.gather(logits, value)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _entropy(self):
|
def _entropy(self):
|
||||||
|
@ -209,7 +211,7 @@ class Categorical(Distribution):
|
||||||
Enumerate categories.
|
Enumerate categories.
|
||||||
"""
|
"""
|
||||||
num_events = self._num_events
|
num_events = self._num_events
|
||||||
values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.int32)
|
values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.float32)
|
||||||
values = self.reshape(values, (num_events, 1))
|
values = self.reshape(values, (num_events, 1))
|
||||||
if expand:
|
if expand:
|
||||||
values = P.BroadcastTo((num_events, self._batch_shape))(values)
|
values = P.BroadcastTo((num_events, self._batch_shape))(values)
|
||||||
|
|
|
@ -204,14 +204,15 @@ def multinomial(inputs, num_sample, replacement=True, seed=0):
|
||||||
but must be non-negative, finite and have a non-zero sum.
|
but must be non-negative, finite and have a non-zero sum.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input (Tensor) - the input tensor containing probabilities, must be 1 or 2 dims.
|
inputs (Tensor): the input tensor containing probabilities, must be 1 or 2 dims. With float32 data type.
|
||||||
num_samples (int) - number of samples to draw.
|
num_sample (int): number of samples to draw.
|
||||||
replacement (bool, optional) - whether to draw with replacement or not, default True.
|
replacement (bool, optional): whether to draw with replacement or not, default True.
|
||||||
seed (int, optional) - used as entropy source for Random number engines generating pseudo-random numbers.
|
seed (int, optional): used as entropy source for Random number engines generating pseudo-random numbers.
|
||||||
Must be non-negative. Default: 0.
|
Must be non-negative. Default: 0.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor. have the same rows with input, each row has num_samples sampled indices.
|
Tensor. have the same rows with input, each row has num_samples sampled indices.
|
||||||
|
The dtype is float32.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> input = Tensor([0, 9, 4, 0], mstype.float32)
|
>>> input = Tensor([0, 9, 4, 0], mstype.float32)
|
||||||
|
|
Loading…
Reference in New Issue