!47711 ascend categorical

Merge pull request !47711 from chenweifeng/ascend-categorical
This commit is contained in:
i-robot 2023-01-10 13:13:53 +00:00 committed by Gitee
commit 72f21eb213
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 1 additions and 3 deletions

View File

@ -25,7 +25,7 @@ import mindspore.nn as nn
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
check_distribution_name, raise_not_implemented_util
check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
@ -405,8 +405,6 @@ class Categorical(Distribution):
Returns:
Tensor, shape is shape(probs)[:-1] + sample_shape
"""
if self.device_target == 'Ascend':
raise_not_implemented_util('On d backend, sample', self.name)
shape = self.checktuple(shape, 'shape')
probs = self._check_param_type(probs)
num_classes = self.shape(probs)[-1]