asccend-categorical

This commit is contained in:
wilfChen 2023-01-10 16:11:59 +08:00
parent 24b3ad9ef0
commit 94b65d60ae
1 changed files with 1 additions and 3 deletions

View File

@ -25,7 +25,7 @@ import mindspore.nn as nn
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import check_prob, check_sum_equal_one, check_rank,\
check_distribution_name, raise_not_implemented_util
check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic, broadcast_to
@ -405,8 +405,6 @@ class Categorical(Distribution):
Returns:
Tensor, shape is shape(probs)[:-1] + sample_shape
"""
if self.device_target == 'Ascend':
raise_not_implemented_util('On d backend, sample', self.name)
shape = self.checktuple(shape, 'shape')
probs = self._check_param_type(probs)
num_classes = self.shape(probs)[-1]