forked from mindspore-Ecosystem/mindspore
!9630 Fix bugs for Tril/Triu and let enumerate run in pynative
From: @liangzhibo Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
2403618aaa
|
@ -177,6 +177,8 @@ class Tensor(Tensor_):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
if isinstance(index, int) and index >= self.shape[0]:
|
||||||
|
raise IndexError("index {} is out of bounds for axis 0 with size {}".format(index, self.shape[0]))
|
||||||
out = tensor_operator_registry.get('__getitem__')(self, index)
|
out = tensor_operator_registry.get('__getitem__')(self, index)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -318,7 +320,7 @@ class Tensor(Tensor_):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shape (Tensor): The input tensor. The shape of input tensor must obey
|
shape (Tensor): The input tensor. The shape of input tensor must obey
|
||||||
the broadcasting rule.
|
the broadcasting rule.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor, has the same dimension as input tensor.
|
Tensor, has the same dimension as input tensor.
|
||||||
|
|
|
@ -775,7 +775,7 @@ class Tril(Cell):
|
||||||
|
|
||||||
def construct(self, x, k=0):
|
def construct(self, x, k=0):
|
||||||
assist = tril(x.shape, self.dtype(x), k)
|
assist = tril(x.shape, self.dtype(x), k)
|
||||||
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
|
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
|
||||||
return self.cast(result, self.dtype(x))
|
return self.cast(result, self.dtype(x))
|
||||||
|
|
||||||
|
|
||||||
|
@ -817,7 +817,7 @@ class Triu(Cell):
|
||||||
|
|
||||||
def construct(self, x, k=0):
|
def construct(self, x, k=0):
|
||||||
assist = triu(x.shape, self.dtype(x), k)
|
assist = triu(x.shape, self.dtype(x), k)
|
||||||
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
|
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
|
||||||
return self.cast(result, self.dtype(x))
|
return self.cast(result, self.dtype(x))
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue