forked from mindspore-Ecosystem/mindspore
Fix bug of tril/triu and enumerate under pynative
This commit is contained in:
parent
3189868a15
commit
2608c17bee
|
@ -176,6 +176,8 @@ class Tensor(Tensor_):
|
|||
return out
|
||||
|
||||
def __getitem__(self, index):
|
||||
if isinstance(index, int) and index >= self.shape[0]:
|
||||
raise IndexError("index {} is out of bounds for axis 0 with size {}".format(index, self.shape[0]))
|
||||
out = tensor_operator_registry.get('__getitem__')(self, index)
|
||||
return out
|
||||
|
||||
|
@ -319,7 +321,7 @@ class Tensor(Tensor_):
|
|||
|
||||
Args:
|
||||
shape (Tensor): The input tensor. The shape of input tensor must obey
|
||||
the broadcasting rule.
|
||||
the broadcasting rule.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same dimension as input tensor.
|
||||
|
|
|
@ -725,7 +725,7 @@ class Tril(Cell):
|
|||
|
||||
def construct(self, x, k=0):
|
||||
assist = tril(x.shape, self.dtype(x), k)
|
||||
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
|
||||
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
|
||||
return self.cast(result, self.dtype(x))
|
||||
|
||||
|
||||
|
@ -767,7 +767,7 @@ class Triu(Cell):
|
|||
|
||||
def construct(self, x, k=0):
|
||||
assist = triu(x.shape, self.dtype(x), k)
|
||||
result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32))
|
||||
result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32))
|
||||
return self.cast(result, self.dtype(x))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue