Fix bug of tril/triu and enumerate under pynative

This commit is contained in:
l00591931 2020-12-08 11:18:04 +08:00
parent 3189868a15
commit 2608c17bee
2 changed files with 5 additions and 3 deletions

View File

@ -176,6 +176,8 @@ class Tensor(Tensor_):
return out
def __getitem__(self, index):
if isinstance(index, int) and index >= self.shape[0]:
raise IndexError("index {} is out of bounds for axis 0 with size {}".format(index, self.shape[0]))
out = tensor_operator_registry.get('__getitem__')(self, index)
return out
@ -319,7 +321,7 @@ class Tensor(Tensor_):
Args:
shape (Tensor): The input tensor. The shape of input tensor must obey
the broadcasting rule.
the broadcasting rule.
Returns:
Tensor, has the same dimension as input tensor.

View File

@ -725,7 +725,7 @@ class Tril(Cell):
def construct(self, x, k=0):
assist = tril(x.shape, self.dtype(x), k)
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
return self.cast(result, self.dtype(x))
@ -767,7 +767,7 @@ class Triu(Cell):
def construct(self, x, k=0):
assist = triu(x.shape, self.dtype(x), k)
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
return self.cast(result, self.dtype(x))