!9630 Fix bugs for Tril/Triu and let enumerate run in pynative

From: @liangzhibo
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2020-12-08 19:13:57 +08:00 committed by Gitee
commit 2403618aaa
2 changed files with 5 additions and 3 deletions

View File

@ -177,6 +177,8 @@ class Tensor(Tensor_):
return out return out
def __getitem__(self, index): def __getitem__(self, index):
if isinstance(index, int) and index >= self.shape[0]:
raise IndexError("index {} is out of bounds for axis 0 with size {}".format(index, self.shape[0]))
out = tensor_operator_registry.get('__getitem__')(self, index) out = tensor_operator_registry.get('__getitem__')(self, index)
return out return out
@ -318,7 +320,7 @@ class Tensor(Tensor_):
Args: Args:
shape (Tensor): The input tensor. The shape of input tensor must obey shape (Tensor): The input tensor. The shape of input tensor must obey
the broadcasting rule. the broadcasting rule.
Returns: Returns:
Tensor, has the same dimension as input tensor. Tensor, has the same dimension as input tensor.

View File

@ -775,7 +775,7 @@ class Tril(Cell):
def construct(self, x, k=0): def construct(self, x, k=0):
assist = tril(x.shape, self.dtype(x), k) assist = tril(x.shape, self.dtype(x), k)
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32)) result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
return self.cast(result, self.dtype(x)) return self.cast(result, self.dtype(x))
@ -817,7 +817,7 @@ class Triu(Cell):
def construct(self, x, k=0): def construct(self, x, k=0):
assist = triu(x.shape, self.dtype(x), k) assist = triu(x.shape, self.dtype(x), k)
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32)) result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
return self.cast(result, self.dtype(x)) return self.cast(result, self.dtype(x))