forked from mindspore-Ecosystem/mindspore
!31849 fix tril && triu to support bool
Merge pull request !31849 from zhuzhongrui/pub_master2
This commit is contained in:
commit
62e4589bb1
|
@ -1242,7 +1242,7 @@ def tril(m, k=0):
|
|||
m = m.astype(mstype.float32)
|
||||
assist = nn_tril(m.shape, mstype.float32, k)
|
||||
# MindSpore binary op do not support bool
|
||||
elif dtype == mstype.Bool:
|
||||
elif dtype == mstype.bool_:
|
||||
m = m.astype(mstype.float32)
|
||||
assist = nn_tril(m.shape, mstype.float32, k)
|
||||
else:
|
||||
|
@ -1288,6 +1288,10 @@ def triu(m, k=0):
|
|||
if device_target == "Ascend":
|
||||
m = m.astype(mstype.float32)
|
||||
assist = nn_triu(m.shape, mstype.float32, k)
|
||||
# MindSpore binary op do not support bool
|
||||
elif dtype == mstype.bool_:
|
||||
m = m.astype(mstype.float32)
|
||||
assist = nn_triu(m.shape, mstype.float32, k)
|
||||
else:
|
||||
assist = nn_triu(m.shape, dtype, k)
|
||||
return F.tensor_mul(assist, m).astype(dtype)
|
||||
|
|
Loading…
Reference in New Issue