!31849 fix tril && triu to support bool

Merge pull request !31849 from zhuzhongrui/pub_master2
This commit is contained in:
i-robot 2022-03-24 15:52:51 +00:00 committed by Gitee
commit 62e4589bb1
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 5 additions and 1 deletions

View File

@ -1242,7 +1242,7 @@ def tril(m, k=0):
m = m.astype(mstype.float32)
assist = nn_tril(m.shape, mstype.float32, k)
# MindSpore binary op do not support bool
elif dtype == mstype.Bool:
elif dtype == mstype.bool_:
m = m.astype(mstype.float32)
assist = nn_tril(m.shape, mstype.float32, k)
else:
@ -1288,6 +1288,10 @@ def triu(m, k=0):
if device_target == "Ascend":
m = m.astype(mstype.float32)
assist = nn_triu(m.shape, mstype.float32, k)
# MindSpore binary op do not support bool
elif dtype == mstype.bool_:
m = m.astype(mstype.float32)
assist = nn_triu(m.shape, mstype.float32, k)
else:
assist = nn_triu(m.shape, dtype, k)
return F.tensor_mul(assist, m).astype(dtype)