forked from mindspore-Ecosystem/mindspore
!9630 Fix bugs for Tril/Triu and let enumerate run in pynative
From: @liangzhibo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
2403618aaa
|
@ -177,6 +177,8 @@ class Tensor(Tensor_):
|
|||
return out
|
||||
|
||||
def __getitem__(self, index):
|
||||
if isinstance(index, int) and index >= self.shape[0]:
|
||||
raise IndexError("index {} is out of bounds for axis 0 with size {}".format(index, self.shape[0]))
|
||||
out = tensor_operator_registry.get('__getitem__')(self, index)
|
||||
return out
|
||||
|
||||
|
@ -318,7 +320,7 @@ class Tensor(Tensor_):
|
|||
|
||||
Args:
|
||||
shape (Tensor): The input tensor. The shape of input tensor must obey
|
||||
the broadcasting rule.
|
||||
the broadcasting rule.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same dimension as input tensor.
|
||||
|
|
|
@ -775,7 +775,7 @@ class Tril(Cell):
|
|||
|
||||
def construct(self, x, k=0):
|
||||
assist = tril(x.shape, self.dtype(x), k)
|
||||
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
|
||||
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
|
||||
return self.cast(result, self.dtype(x))
|
||||
|
||||
|
||||
|
@ -817,7 +817,7 @@ class Triu(Cell):
|
|||
|
||||
def construct(self, x, k=0):
|
||||
assist = triu(x.shape, self.dtype(x), k)
|
||||
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
|
||||
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
|
||||
return self.cast(result, self.dtype(x))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue