forked from mindspore-Ecosystem/mindspore
!47711 ascend categorical
Merge pull request !47711 from chenweifeng/ascend-categorical
This commit is contained in:
commit
72f21eb213
|
@ -25,7 +25,7 @@ import mindspore.nn as nn
|
|||
from mindspore.common import dtype as mstype
|
||||
from .distribution import Distribution
|
||||
from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
|
||||
check_distribution_name, raise_not_implemented_util
|
||||
check_distribution_name
|
||||
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
|
||||
|
||||
|
||||
|
@ -405,8 +405,6 @@ class Categorical(Distribution):
|
|||
Returns:
|
||||
Tensor, shape is shape(probs)[:-1] + sample_shape
|
||||
"""
|
||||
if self.device_target == 'Ascend':
|
||||
raise_not_implemented_util('On d backend, sample', self.name)
|
||||
shape = self.checktuple(shape, 'shape')
|
||||
probs = self._check_param_type(probs)
|
||||
num_classes = self.shape(probs)[-1]
|
||||
|
|
Loading…
Reference in New Issue