added categorical distribution

parent c1b9efe8e6
commit 877b561e77
@@ -158,6 +158,18 @@ def check_prob(p):
     if not comp.all():
         raise ValueError('Probabilities should be less than one')

+def check_sum_equal_one(probs):
+    prob_sum = np.sum(probs.asnumpy(), axis=-1)
+    comp = np.equal(np.ones(prob_sum.shape), prob_sum)
+    if not comp.all():
+        raise ValueError('Probabilities for each category should sum to one for Categorical distribution.')
+
+def check_rank(probs):
+    """
+    Used in categorical distribution. Check that rank >= 1.
+    """
+    if probs.asnumpy().ndim == 0:
+        raise ValueError('probs for Categorical distribution must have rank >= 1.')
+
 def logits_to_probs(logits, is_binary=False):
     """
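To see concretely what the new validators accept and reject, here is a plain-NumPy sketch (the probs_ok helper is illustrative only, not part of the commit; note that check_sum_equal_one above compares with exact equality via np.equal, while this sketch allows a floating-point tolerance):

import numpy as np

def probs_ok(probs):
    # Mirrors the three checks: rank >= 1 (check_rank), every entry strictly
    # in (0, 1) (check_prob), and each row summing to one (check_sum_equal_one).
    p = np.asarray(probs, dtype=np.float32)
    if p.ndim == 0:
        return False
    if not ((p > 0).all() and (p < 1).all()):
        return False
    return bool(np.allclose(p.sum(axis=-1), 1.0))

probs_ok([0.7, 0.3])                  # True
probs_ok(0.3)                         # False: rank 0
probs_ok([[0.5, 0.7], [0.6, 0.6]])    # False: rows sum to 1.2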
@@ -13,108 +13,150 @@
 # limitations under the License.
 # ============================================================================
 """Categorical Distribution"""
 import numpy as np
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 import mindspore.nn as nn
 from mindspore.common import dtype as mstype
 from .distribution import Distribution
-from ._utils.utils import logits_to_probs, probs_to_logits, check_type, cast_to_tensor, \
-    raise_probs_logits_error
+from ._utils.utils import check_prob, check_sum_equal_one, check_type, check_rank,\
+    check_distribution_name, raise_not_implemented_util
+from ._utils.custom_ops import exp_generic, log_generic, broadcast_to


 class Categorical(Distribution):
     """
-    Create a categorical distribution parameterized by either probabilities or logits (but not both).
+    Create a categorical distribution parameterized by event probabilities.

     Args:
         probs (Tensor, list, numpy.ndarray, Parameter): Event probabilities.
-        logits (Tensor, list, numpy.ndarray, Parameter, float): Event log-odds.
         seed (int): The seed used in sampling. The global seed is used if it is None. Default: None.
         dtype (mindspore.dtype): The type of the distribution. Default: mstype.int32.
         name (str): The name of the distribution. Default: Categorical.

     Note:
-        `probs` must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1.
+        `probs` must have rank at least 1, contain proper probabilities, and sum to 1.

     Examples:
-        >>> # To initialize a Categorical distribution of prob is [0.5, 0.5]
+        >>> # To initialize a Categorical distribution of probs [0.5, 0.5]
         >>> import mindspore.nn.probability.distribution as msd
         >>> b = msd.Categorical(probs = [0.5, 0.5], dtype=mstype.int32)
         >>>
-        >>> # To use Categorical in a network
+        >>> # To use a Categorical distribution in a network
         >>> class net(Cell):
         >>>     def __init__(self, probs):
         >>>         super(net, self).__init__()
-        >>>         self.ca = msd.Categorical(probs=probs, dtype=mstype.int32)
+        >>>         self.ca = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32)
+        >>>         self.ca1 = msd.Categorical(probs=[0.2, 0.8], dtype=mstype.int32)
         >>>
         >>>     # All the following calls in construct are valid
         >>>     def construct(self, value):
         >>>
-        >>>         # Similar calls can be made to logits
-        >>>         ans = self.ca.probs
-        >>>         # value must be Tensor(mstype.float32, bool, mstype.int32)
-        >>>         ans = self.ca.log_prob(value)
+        >>>         # Private interfaces of probability functions corresponding to public interfaces, including
+        >>>         # `prob`, `log_prob`, `cdf`, `log_cdf`, `survival_function`, and `log_survival`, are the same as follows.
+        >>>         # Args:
+        >>>         #     value (Tensor): the value to be evaluated.
+        >>>         #     probs (Tensor): event probabilities. Default: self.probs.
         >>>
-        >>>         # Usage of enumerate_support
-        >>>         ans = self.ca.enumerate_support()
+        >>>         # Examples of `prob`.
+        >>>         # Similar calls can be made to other probability functions
+        >>>         # by replacing `prob` by the name of the function.
+        >>>         ans = self.ca.prob(value)
+        >>>         # Evaluate `prob` with respect to distribution b.
+        >>>         ans = self.ca.prob(value, probs_b)
+        >>>         # `probs` must be passed in during function calls.
+        >>>         ans = self.ca1.prob(value, probs_a)
         >>>
-        >>>         # Usage of entropy
-        >>>         ans = self.ca.entropy()
+        >>>         # Functions `mean`, `sd`, `var`, and `entropy` have the same arguments.
+        >>>         # Args:
+        >>>         #     probs (Tensor): event probabilities. Default: self.probs.
         >>>
-        >>>         # Sample
+        >>>         # Examples of `mean`. `sd`, `var`, and `entropy` are similar.
+        >>>         ans = self.ca.mean() # return 0.8
+        >>>         ans = self.ca.mean(probs_b)
+        >>>         # `probs` must be passed in during function calls.
+        >>>         ans = self.ca1.mean(probs_a)
         >>>
+        >>>         # Interfaces of `kl_loss` and `cross_entropy` are the same as follows:
+        >>>         # Args:
+        >>>         #     dist (str): the name of the distribution. Only 'Categorical' is supported.
+        >>>         #     probs_b (Tensor): event probabilities of distribution b.
+        >>>         #     probs (Tensor): event probabilities of distribution a. Default: self.probs.
+        >>>
+        >>>         # Examples of kl_loss. `cross_entropy` is similar.
+        >>>         ans = self.ca.kl_loss('Categorical', probs_b)
+        >>>         ans = self.ca.kl_loss('Categorical', probs_b, probs_a)
+        >>>         # An additional `probs` must be passed in.
+        >>>         ans = self.ca1.kl_loss('Categorical', probs_b, probs_a)
+        >>>
+        >>>         # Examples of `sample`.
+        >>>         # Args:
+        >>>         #     shape (tuple): the shape of the sample. Default: ().
+        >>>         #     probs (Tensor): event probabilities. Default: self.probs.
         >>>         ans = self.ca.sample()
         >>>         ans = self.ca.sample((2,3))
-        >>>         ans = self.ca.sample((2,))
+        >>>         ans = self.b1.sample((2,3), probs_b)
+        >>>         ans = self.b2.sample((2,3), probs_a)
     """

     def __init__(self,
                  probs=None,
-                 logits=None,
                  seed=None,
                  dtype=mstype.int32,
                  name="Categorical"):
         param = dict(locals())
-        param['param_dict'] = {'probs': probs, 'logits': logits}
+        param['param_dict'] = {'probs': probs}
         valid_dtype = mstype.int_type
         check_type(dtype, valid_dtype, "Categorical")
         super(Categorical, self).__init__(seed, dtype, name, param)
-        if (probs is None) == (logits is None):
-            raise_probs_logits_error()
-        self.reduce_sum = P.ReduceSum(keep_dims=True)
-        self.reduce_sum1 = P.ReduceSum(keep_dims=False)
-        self.log = P.Log()
-        self.exp = P.Exp()
-        self.shape = P.Shape()
-        self.reshape = P.Reshape()
-        self.div = P.RealDiv()
-        self.size = P.Size()
-        self.mutinomial = P.Multinomial(seed=self.seed)
-        self.cast = P.Cast()
-        self.expandim = P.ExpandDims()
-        self.gather = P.GatherNd()
-        self.concat = P.Concat(-1)
-        self.transpose = P.Transpose()
-        if probs is not None:
-            self._probs = cast_to_tensor(probs, mstype.float32)
-            input_sum = self.reduce_sum(self._probs, -1)
-            self._probs = self.div(self._probs, input_sum)
-            self._logits = probs_to_logits(self._probs)
-            self._param = self._probs
-        else:
-            self._logits = cast_to_tensor(logits, mstype.float32)
-            input_sum = self.reduce_sum(self.exp(self._logits), -1)
-            self._logits = self._logits - self.log(input_sum)
-            self._probs = logits_to_probs(self._logits)
-            self._param = self._logits
-        self._num_events = self.shape(self._param)[-1]
-        self._param2d = self.reshape(self._param, (-1, self._num_events))
-        self._batch_shape = self.shape(self._param)[:-1]
-        self._batch_shape_n = (1,) * len(self._batch_shape)

-    @property
-    def logits(self):
-        """
-        Return the logits.
-        """
-        return self._logits
+        self._probs = self._add_parameter(probs, 'probs')
+        if self.probs is not None:
+            check_rank(self.probs)
+            check_prob(self.probs)
+            check_sum_equal_one(self.probs)
+
+            # update is_scalar_batch and broadcast_shape
+            # drop one dimension
+            if self.probs.shape[:-1] == ():
+                self._is_scalar_batch = True
+                self._broadcast_shape = self._broadcast_shape[:-1]
+
+        self.argmax = P.Argmax()
+        self.broadcast = broadcast_to
+        self.cast = P.Cast()
+        self.clip_by_value = C.clip_by_value
+        self.concat = P.Concat(-1)
+        self.cumsum = P.CumSum()
+        self.dtypeop = P.DType()
+        self.exp = exp_generic
+        self.expand_dim = P.ExpandDims()
+        self.fill = P.Fill()
+        self.floor = P.Floor()
+        self.gather = P.GatherNd()
+        self.less = P.Less()
+        self.log = log_generic
+        self.log_softmax = P.LogSoftmax()
+        self.logicor = P.LogicalOr()
+        self.multinomial = P.Multinomial(seed=self.seed)
+        self.reshape = P.Reshape()
+        self.reduce_sum = P.ReduceSum(keep_dims=True)
+        self.select = P.Select()
+        self.shape = P.Shape()
+        self.softmax = P.Softmax()
+        self.squeeze = P.Squeeze()
+        self.square = P.Square()
+        self.transpose = P.Transpose()
+
+        self.index_type = mstype.int32

     def extend_repr(self):
         if self.is_scalar_batch:
             str_info = f'probs = {self.probs}'
         else:
             str_info = f'batch_shape = {self._broadcast_shape}'
         return str_info

     @property
     def probs(self):
@@ -123,68 +165,214 @@ class Categorical(Distribution):
         """
         return self._probs

-    def _sample(self, sample_shape=()):
+    def _mean(self, probs=None):
+        r"""
+        .. math::
+            E[X] = \sum_{i=0}^{num_classes-1} i*p_i
+        """
-        Sampling.
+        probs = self._check_param_type(probs)
+        num_classes = self.shape(probs)[-1]
+        index = nn.Range(0., num_classes, 1.)()
+        return self.reduce_sum(index * probs, -1)
+
+    def _mode(self, probs=None):
+        probs = self._check_param_type(probs)
+        mode = self.cast(self.argmax(probs), self.dtype)
+        return self.squeeze(mode)
+
+    def _var(self, probs=None):
+        r"""
+        .. math::
+            VAR(X) = E[X^{2}] - (E[X])^{2}
+        """
+        probs = self._check_param_type(probs)
+        num_classes = self.shape(probs)[-1]
+        index = nn.Range(0., num_classes, 1.)()
+        return self.reduce_sum(self.square(index) * probs, -1) -\
+               self.square(self.reduce_sum(index * probs, -1))
+
+    def _entropy(self, probs=None):
+        r"""
+        Evaluate entropy.
+
+        .. math::
+            H(X) = -\sum(logits * probs)
+        """
+        probs = self._check_param_type(probs)
+        logits = self.log(probs)
+        return self.squeeze(-self.reduce_sum(logits * probs, -1))
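These closed forms are easy to sanity-check in plain NumPy; a sketch with probs = [0.2, 0.1, 0.7], matching the expectations in the tests further down:

import numpy as np

p = np.array([0.2, 0.1, 0.7])
i = np.arange(p.size, dtype=float)      # category indices 0 .. num_classes-1

mean = np.sum(i * p)                    # E[X] = sum_i i * p_i          -> 1.5
var = np.sum(i**2 * p) - mean**2        # E[X^2] - (E[X])^2             -> 0.65
entropy = -np.sum(p * np.log(p))        # H(X) = -sum_i p_i * log(p_i)  -> ~0.8018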

+    def _kl_loss(self, dist, probs_b, probs=None):
+        """
+        Evaluate KL divergence between Categorical distributions.
+
         Args:
-            sample_shape (tuple): The shape of the sample. Default: ().
-
-        Returns:
-            Tensor, shape is shape(probs)[:-1] + sample_shape
+            dist (str): The type of the distributions. Should be "Categorical" in this case.
+            probs_b (Tensor): Event probabilities of distribution b.
+            probs (Tensor): Event probabilities of distribution a. Default: self.probs.
         """
-        self.checktuple(sample_shape, 'shape')
-        num_sample = 1
-        for i in sample_shape:
-            num_sample *= i
-        probs_2d = self.reshape(self._probs, (-1, self._num_events))
-        samples = self.mutinomial(probs_2d, num_sample)
-        samples = self.transpose(samples, (1, 0))
-        extend_shape = sample_shape
-        if len(self.shape(self._probs)) > 1:
-            extend_shape = sample_shape + self.shape(self._probs)[:-1]
-        return self.cast(self.reshape(samples, extend_shape), self.dtype)
+        check_distribution_name(dist, 'Categorical')
+        probs_b = self._check_value(probs_b, 'probs_b')
+        probs_b = self.cast(probs_b, self.parameter_type)
+        probs_a = self._check_param_type(probs)
+        logits_a = self.log(probs_a)
+        logits_b = self.log(probs_b)
+        return self.squeeze(-self.reduce_sum(
+            self.softmax(logits_a) * (self.log_softmax(logits_a) - (self.log_softmax(logits_b))), -1))

-    def _log_prob(self, value):
+    def _cross_entropy(self, dist, probs_b, probs=None):
+        """
+        Evaluate cross entropy between Categorical distributions.
+
+        Args:
+            dist (str): The type of the distributions. Should be "Categorical" in this case.
+            probs_b (Tensor): Event probabilities of distribution b.
+            probs (Tensor): Event probabilities of distribution a. Default: self.probs.
+        """
+        check_distribution_name(dist, 'Categorical')
+        return self._entropy(probs) + self._kl_loss(dist, probs_b, probs)
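In plain NumPy, the two quantities and the identity the cross entropy relies on, H(a, b) = H(a) + KL(a || b) = -sum(a * log(b)), look like this (a sketch, not the graph-mode implementation):

import numpy as np

a = np.array([0.7, 0.2, 0.1])    # probs of distribution a
b = np.array([0.3, 0.1, 0.6])    # probs of distribution b

kl = np.sum(a * (np.log(a) - np.log(b)))    # KL(a || b)
cross = -np.sum(a * np.log(a)) + kl         # H(a) + KL(a || b)
assert np.isclose(cross, -np.sum(a * np.log(b)))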

+    def _log_prob(self, value, probs=None):
         r"""
         Evaluate log probability.

         Args:
             value (Tensor): The value to be evaluated.
+            probs (Tensor): Event probabilities. Default: self.probs.
         """
         value = self._check_value(value, 'value')
-        value = self.expandim(self.cast(value, mstype.float32), -1)
-        broad_shape = self.shape(value + self._logits)
-        broad = P.BroadcastTo(broad_shape)
-        logits_pmf = self.reshape(broad(self._logits), (-1, broad_shape[-1]))
-        value = self.reshape(broad(value)[..., :1], (-1, 1))
-        index = nn.Range(0., self.shape(value)[0], 1)()
-        index = self.reshape(index, (-1, 1))
-        value = self.concat((index, value))
-        value = self.cast(value, mstype.int32)
-        return self.reshape(self.gather(logits_pmf, value), broad_shape[:-1])
+        value = self.cast(value, self.parameter_type)
+        probs = self._check_param_type(probs)
+        logits = self.log(probs)

-    def _entropy(self):
+        # handle the case when value is of shape () and probs is a scalar batch
+        drop_dim = False
+        if self.shape(value) == () and self.shape(probs)[:-1] == ():
+            drop_dim = True
+            # manually add one more dimension: () -> (1,)
+            # drop this dimension before return
+            value = self.expand_dim(value, -1)
+
+        value = self.expand_dim(value, -1)
+
+        broadcast_shape_tensor = logits * value
+        broadcast_shape = self.shape(broadcast_shape_tensor)
+        # broadcast_shape (N, C)
+        num_classes = broadcast_shape[-1]
+        label_shape = broadcast_shape[:-1]
+
+        # broadcasting logits and value
+        # logit_pmf shape (num of labels, C)
+        logits = self.broadcast(logits, broadcast_shape_tensor)
+        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]
+
+        # flatten value to shape (number of labels, 1)
+        # clip value to be in range from 0 to num_classes -1 and cast into int32
+        value = self.reshape(value, (-1, 1))
+        out_of_bound = self.squeeze(self.logicor(\
+            self.less(value, 0.0), self.less(num_classes-1, value)))
+        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
+        value_clipped = self.cast(value_clipped, self.index_type)
+        # create index from 0 ... NumOfLabels
+        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
+        index = self.concat((index, value_clipped))
+
+        # index into logit_pmf, fill in out_of_bound places with -inf
+        # reshape into label shape N
+        logits_pmf = self.gather(self.reshape(logits, (-1, num_classes)), index)
+        neg_inf = self.fill(self.dtypeop(logits_pmf), self.shape(logits_pmf), -np.inf)
+        logits_pmf = self.select(out_of_bound, neg_inf, logits_pmf)
+        ans = self.reshape(logits_pmf, label_shape)
+        if drop_dim:
+            return self.squeeze(ans)
+        return ans
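Stripped of broadcasting and graph-mode plumbing, the gather-and-mask logic above reduces to the following NumPy sketch (out-of-support values receive log-probability -inf):

import numpy as np

log_p = np.log(np.array([0.7, 0.3]))    # logits = log(probs)
value = np.array([0., 1., -1., 2.])     # -1 and 2 are outside the support

idx = np.clip(value, 0, log_p.size - 1).astype(np.int32)
out = log_p[idx]
out[(value < 0) | (value > log_p.size - 1)] = -np.inf
# out == [log 0.7, log 0.3, -inf, -inf]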

+    def _cdf(self, value, probs=None):
         r"""
-        Evaluate entropy.
+        Cumulative distribution function (cdf) of Categorical distributions.

-        .. math::
-            H(X) = -\sum(logits * probs)
-        """
-        p_log_p = self._logits * self._probs
-        return self.reduce_sum1(-p_log_p, -1)
+        Args:
+            value (Tensor): The value to be evaluated.
+            probs (Tensor): Event probabilities. Default: self.probs.
+        """
+        value = self._check_value(value, 'value')
+        value = self.cast(value, self.parameter_type)
+        value = self.floor(value)
+        probs = self._check_param_type(probs)

-    def enumerate_support(self, expand=True):
-        r"""
-        Enumerate categories.
+        # handle the case when value is of shape () and probs is a scalar batch
+        drop_dim = False
+        if self.shape(value) == () and self.shape(probs)[:-1] == ():
+            drop_dim = True
+            # manually add one more dimension: () -> (1,)
+            # drop this dimension before return
+            value = self.expand_dim(value, -1)

-        Args:
-            expand (Bool): Whether to expand.
-        """
-        num_events = self._num_events
-        values = nn.Range(0., num_events, 1)()
-        values = self.reshape(values, (num_events,) + self._batch_shape_n)
-        if expand:
-            values = P.BroadcastTo((num_events,) + self._batch_shape)(values)
-        values = self.cast(values, mstype.int32)
-        return values
+        value = self.expand_dim(value, -1)
+
+        broadcast_shape_tensor = probs * value
+        broadcast_shape = self.shape(broadcast_shape_tensor)
+        # broadcast_shape (N, C)
+        num_classes = broadcast_shape[-1]
+        label_shape = broadcast_shape[:-1]
+
+        probs = self.broadcast(probs, broadcast_shape_tensor)
+        value = self.broadcast(value, broadcast_shape_tensor)[..., :1]
+
+        # flatten value to shape (number of labels, 1)
+        value = self.reshape(value, (-1, 1))
+
+        # drop one dimension to match cdf
+        # clip value to be in range from 0 to num_classes -1 and cast into int32
+        less_than_zero = self.squeeze(self.less(value, 0.0))
+        value_clipped = self.clip_by_value(value, 0.0, num_classes - 1)
+        value_clipped = self.cast(value_clipped, self.index_type)
+
+        index = self.reshape(nn.Range(0, self.shape(value)[0], 1)(), (-1, 1))
+        index = self.concat((index, value_clipped))
+
+        # reshape probs and fill less_than_zero places with 0
+        probs = self.reshape(probs, (-1, num_classes))
+        cdf = self.gather(self.cumsum(probs, 1), index)
+        zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)
+        cdf = self.select(less_than_zero, zeros, cdf)
+        cdf = self.reshape(cdf, label_shape)
+
+        if drop_dim:
+            return self.squeeze(cdf)
+        return cdf
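The cdf follows the same pattern with a cumulative sum; a NumPy sketch of the behavior (values below the support map to 0):

import numpy as np

p = np.array([0.7, 0.3])
value = np.array([-1.0, 0.0, 0.5, 1.0])

v = np.floor(value)
idx = np.clip(v, 0, p.size - 1).astype(np.int32)
cdf = np.cumsum(p)[idx]
cdf[v < 0] = 0.0
# cdf == [0.0, 0.7, 0.7, 1.0]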

+    def _sample(self, shape=(), probs=None):
+        """
+        Sampling.
+
+        Args:
+            shape (tuple): The shape of the sample. Default: ().
+            probs (Tensor): Event probabilities. Default: self.probs.
+
+        Returns:
+            Tensor, shape is shape + batch_shape.
+        """
+        if self.device_target == 'Ascend':
+            raise_not_implemented_util('On d backend, sample', self.name)
+        shape = self.checktuple(shape, 'shape')
+        probs = self._check_param_type(probs)
+        num_classes = self.shape(probs)[-1]
+        batch_shape = self.shape(probs)[:-1]
+
+        sample_shape = shape + batch_shape
+        drop_dim = False
+        if sample_shape == ():
+            drop_dim = True
+            sample_shape = (1,)
+
+        probs_2d = self.reshape(probs, (-1, num_classes))
+        sample_tensor = self.fill(self.dtype, shape, 1.0)
+        sample_tensor = self.reshape(sample_tensor, (-1, 1))
+        num_sample = self.shape(sample_tensor)[0]
+        samples = self.multinomial(probs_2d, num_sample)
+        samples = self.squeeze(self.transpose(samples, (1, 0)))
+        samples = self.cast(self.reshape(samples, sample_shape), self.dtype)
+        if drop_dim:
+            return self.squeeze(samples)
+        return samples
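Conceptually, the sampler flattens probs into rows, draws prod(shape) categorical indices per row via Multinomial, and reshapes to shape + batch_shape. A rough NumPy analogue of those semantics (a sketch, not the operator itself):

import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.2, 0.1, 0.7])
shape = (2, 3)

num = int(np.prod(shape))                          # samples to draw per batch row
samples = rng.choice(p.size, size=num, p=p)        # categorical draws in [0, num_classes)
samples = samples.reshape(shape).astype(np.int32)  # back to the requested shape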

@@ -96,6 +96,7 @@ class Distribution(Cell):
         self._set_cross_entropy()

         self.context_mode = context.get_context('mode')
+        self.device_target = context.get_context('device_target')
         self.checktuple = CheckTuple()
         self.checktensor = CheckTensor()
         self.broadcast = broadcast_to
@@ -0,0 +1,273 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for categorical distribution"""
import numpy as np
import pytest
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Prob(nn.Cell):
    """
    Test class: probability of categorical distribution.
    """
    def __init__(self):
        super(Prob, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.prob(x_)


def test_pmf():
    """
    Test pmf.
    """
    expect_pmf = [0.7, 0.3, 0.7, 0.3, 0.3]
    pmf = Prob()
    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
    output = pmf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()


class LogProb(nn.Cell):
    """
    Test class: log probability of categorical distribution.
    """
    def __init__(self):
        super(LogProb, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.log_prob(x_)


def test_log_likelihood():
    """
    Test log_pmf.
    """
    expect_logpmf = np.log([0.7, 0.3, 0.7, 0.3, 0.3])
    logprob = LogProb()
    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
    output = logprob(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()


class KL(nn.Cell):
    """
    Test class: kl_loss between categorical distributions.
    """
    def __init__(self):
        super(KL, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.kl_loss('Categorical', x_)


def test_kl_loss():
    """
    Test kl_loss.
    """
    kl_loss = KL()
    output = kl_loss(Tensor([0.7, 0.3], dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy()) < tol).all()


class Sampling(nn.Cell):
    """
    Test class: sampling of categorical distribution.
    """
    def __init__(self):
        super(Sampling, self).__init__()
        self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)
        self.shape = (2, 3)

    def construct(self):
        return self.c.sample(self.shape)


def test_sample():
    """
    Test sample.
    """
    with pytest.raises(NotImplementedError):
        sample = Sampling()
        sample()


class Basics(nn.Cell):
    """
    Test class: mean/var/mode of categorical distribution.
    """
    def __init__(self):
        super(Basics, self).__init__()
        self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)

    def construct(self):
        return self.c.mean(), self.c.var(), self.c.mode()


def test_basics():
    """
    Test mean/variance/mode.
    """
    basics = Basics()
    mean, var, mode = basics()
    expect_mean = 0 * 0.2 + 1 * 0.1 + 2 * 0.7
    expect_var = 0 * 0.2 + 1 * 0.1 + 4 * 0.7 - (expect_mean * expect_mean)
    expect_mode = 2
    tol = 1e-6
    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
    assert (np.abs(var.asnumpy() - expect_var) < tol).all()
    assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()


class CDF(nn.Cell):
    """
    Test class: cdf of categorical distributions.
    """
    def __init__(self):
        super(CDF, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.cdf(x_)


def test_cdf():
    """
    Test cdf.
    """
    expect_cdf = [0.7, 0.7, 1, 0.7, 1]
    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
    cdf = CDF()
    output = cdf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()


class LogCDF(nn.Cell):
    """
    Test class: log cdf of categorical distributions.
    """
    def __init__(self):
        super(LogCDF, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.log_cdf(x_)


def test_logcdf():
    """
    Test log_cdf.
    """
    expect_logcdf = np.log([0.7, 0.7, 1, 0.7, 1])
    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
    logcdf = LogCDF()
    output = logcdf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()


class SF(nn.Cell):
    """
    Test class: survival function of categorical distributions.
    """
    def __init__(self):
        super(SF, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.survival_function(x_)


def test_survival():
    """
    Test survival function.
    """
    expect_survival = [0.3, 0., 0., 0.3, 0.3]
    x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32), dtype=dtype.float32)
    sf = SF()
    output = sf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()


class LogSF(nn.Cell):
    """
    Test class: log survival function of categorical distributions.
    """
    def __init__(self):
        super(LogSF, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.log_survival(x_)


def test_log_survival():
    """
    Test log survival function.
    """
    expect_logsurvival = np.log([1., 0.3, 0.3, 0.3, 0.3])
    x_ = Tensor(np.array([-0.1, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32)
    log_sf = LogSF()
    output = log_sf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()


class EntropyH(nn.Cell):
    """
    Test class: entropy of categorical distributions.
    """
    def __init__(self):
        super(EntropyH, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self):
        return self.c.entropy()


def test_entropy():
    """
    Test entropy.
    """
    cat_benchmark = stats.multinomial(n=1, p=[0.7, 0.3])
    expect_entropy = cat_benchmark.entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()


class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between categorical distributions.
    """
    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        entropy = self.c.entropy()
        kl_loss = self.c.kl_loss('Categorical', x_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.c.cross_entropy('Categorical', x_)
        return h_sum_kl - cross_entropy


def test_cross_entropy():
    """
    Test cross_entropy.
    """
    cross_entropy = CrossEntropy()
    prob = Tensor([0.7, 0.3], dtype=dtype.float32)
    diff = cross_entropy(prob)
    tol = 1e-6
    assert (np.abs(diff.asnumpy()) < tol).all()
@@ -0,0 +1,249 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.Categorical.
"""
import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor


def test_arguments():
    """
    Args passing during initialization.
    """
    c = msd.Categorical()
    assert isinstance(c, msd.Distribution)
    c = msd.Categorical([0.1, 0.9], dtype=dtype.int32)
    assert isinstance(c, msd.Distribution)


def test_type():
    with pytest.raises(TypeError):
        msd.Categorical([0.1], dtype=dtype.bool_)


def test_name():
    with pytest.raises(TypeError):
        msd.Categorical([0.1], name=1.0)


def test_seed():
    with pytest.raises(TypeError):
        msd.Categorical([0.1], seed='seed')


def test_prob():
    """
    Invalid probability.
    """
    with pytest.raises(ValueError):
        msd.Categorical([-0.1], dtype=dtype.int32)
    with pytest.raises(ValueError):
        msd.Categorical([1.1], dtype=dtype.int32)
    with pytest.raises(ValueError):
        msd.Categorical([0.0], dtype=dtype.int32)
    with pytest.raises(ValueError):
        msd.Categorical([1.0], dtype=dtype.int32)


def test_categorical_sum():
    """
    Invalid probabilities: rows do not sum to one.
    """
    with pytest.raises(ValueError):
        msd.Categorical([[0.1, 0.2], [0.4, 0.6]], dtype=dtype.int32)
    with pytest.raises(ValueError):
        msd.Categorical([[0.5, 0.7], [0.6, 0.6]], dtype=dtype.int32)


def test_rank():
    """
    Rank (dimension) of probs less than 1.
    """
    with pytest.raises(ValueError):
        msd.Categorical(0.2, dtype=dtype.int32)
    with pytest.raises(ValueError):
        msd.Categorical(np.array(0.3).astype(np.float32), dtype=dtype.int32)
    with pytest.raises(ValueError):
        msd.Categorical(Tensor(np.array(0.3).astype(np.float32)), dtype=dtype.int32)


class CategoricalProb(nn.Cell):
    """
    Categorical distribution: initialize with probs.
    """

    def __init__(self):
        super(CategoricalProb, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, value):
        prob = self.c.prob(value)
        log_prob = self.c.log_prob(value)
        cdf = self.c.cdf(value)
        log_cdf = self.c.log_cdf(value)
        sf = self.c.survival_function(value)
        log_sf = self.c.log_survival(value)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_categorical_prob():
    """
    Test probability functions: passing value through construct.
    """
    net = CategoricalProb()
    value = Tensor([0, 1, 0, 1, 0], dtype=dtype.float32)
    ans = net(value)
    assert isinstance(ans, Tensor)


class CategoricalProb1(nn.Cell):
    """
    Categorical distribution: initialize without probs.
    """

    def __init__(self):
        super(CategoricalProb1, self).__init__()
        self.c = msd.Categorical(dtype=dtype.int32)

    def construct(self, value, probs):
        prob = self.c.prob(value, probs)
        log_prob = self.c.log_prob(value, probs)
        cdf = self.c.cdf(value, probs)
        log_cdf = self.c.log_cdf(value, probs)
        sf = self.c.survival_function(value, probs)
        log_sf = self.c.log_survival(value, probs)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_categorical_prob1():
    """
    Test probability functions: passing value/probs through construct.
    """
    net = CategoricalProb1()
    value = Tensor([0, 1, 0, 1, 0], dtype=dtype.float32)
    probs = Tensor([0.3, 0.7], dtype=dtype.float32)
    ans = net(value, probs)
    assert isinstance(ans, Tensor)


class CategoricalKl(nn.Cell):
    """
    Test class: kl_loss between Categorical distributions.
    """

    def __init__(self):
        super(CategoricalKl, self).__init__()
        self.c1 = msd.Categorical([0.2, 0.2, 0.6], dtype=dtype.int32)
        self.c2 = msd.Categorical(dtype=dtype.int32)

    def construct(self, probs_b, probs_a):
        kl1 = self.c1.kl_loss('Categorical', probs_b)
        kl2 = self.c2.kl_loss('Categorical', probs_b, probs_a)
        return kl1 + kl2


def test_kl():
    """
    Test kl_loss function.
    """
    ber_net = CategoricalKl()
    probs_b = Tensor([0.3, 0.1, 0.6], dtype=dtype.float32)
    probs_a = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32)
    ans = ber_net(probs_b, probs_a)
    assert isinstance(ans, Tensor)


class CategoricalCrossEntropy(nn.Cell):
    """
    Test class: cross_entropy of Categorical distribution.
    """

    def __init__(self):
        super(CategoricalCrossEntropy, self).__init__()
        self.c1 = msd.Categorical([0.1, 0.7, 0.2], dtype=dtype.int32)
        self.c2 = msd.Categorical(dtype=dtype.int32)

    def construct(self, probs_b, probs_a):
        h1 = self.c1.cross_entropy('Categorical', probs_b)
        h2 = self.c2.cross_entropy('Categorical', probs_b, probs_a)
        return h1 + h2


def test_cross_entropy():
    """
    Test cross_entropy between Categorical distributions.
    """
    net = CategoricalCrossEntropy()
    probs_b = Tensor([0.3, 0.1, 0.6], dtype=dtype.float32)
    probs_a = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32)
    ans = net(probs_b, probs_a)
    assert isinstance(ans, Tensor)


class CategoricalConstruct(nn.Cell):
    """
    Categorical distribution: going through construct.
    """

    def __init__(self):
        super(CategoricalConstruct, self).__init__()
        self.c = msd.Categorical([0.1, 0.8, 0.1], dtype=dtype.int32)
        self.c1 = msd.Categorical(dtype=dtype.int32)

    def construct(self, value, probs):
        prob = self.c('prob', value)
        prob1 = self.c('prob', value, probs)
        prob2 = self.c1('prob', value, probs)
        return prob + prob1 + prob2


def test_categorical_construct():
    """
    Test probability function going through construct.
    """
    net = CategoricalConstruct()
    value = Tensor([0, 1, 2, 0, 0], dtype=dtype.float32)
    probs = Tensor([0.5, 0.4, 0.1], dtype=dtype.float32)
    ans = net(value, probs)
    assert isinstance(ans, Tensor)


class CategoricalBasics(nn.Cell):
    """
    Test class: basic mean/var/mode/entropy function.
    """

    def __init__(self):
        super(CategoricalBasics, self).__init__()
        self.c = msd.Categorical([0.2, 0.7, 0.1], dtype=dtype.int32)
        self.c1 = msd.Categorical(dtype=dtype.int32)

    def construct(self, probs):
        basics1 = self.c.mean() + self.c.var() + self.c.mode() + self.c.entropy()
        basics2 = self.c1.mean(probs) + self.c1.var(probs) +\
                  self.c1.mode(probs) + self.c1.entropy(probs)
        return basics1 + basics2


def test_basics():
    """
    Test basics functionality of Categorical distribution.
    """
    net = CategoricalBasics()
    probs = Tensor([0.7, 0.2, 0.1], dtype=dtype.float32)
    ans = net(probs)
    assert isinstance(ans, Tensor)