!231 add bool type check in communication operator

Merge pull request !231 from chentingting/add_bool_type_check_in_comm_op
This commit is contained in:
mindspore-ci-bot 2020-04-11 14:17:49 +08:00 committed by Gitee
commit 54481c30c8
1 changed files with 8 additions and 0 deletions

View File

@ -162,6 +162,8 @@ class AllGather(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("AllGather does not support 'Bool' as the dtype of input!")
return x_dtype
def __call__(self, tensor):
@ -219,6 +221,8 @@ class ReduceScatter(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!")
return x_dtype
def __call__(self, tensor):
@ -276,6 +280,8 @@ class Broadcast(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("Broadcast does not support 'Bool' as the dtype of input!")
return x_dtype
@ -318,6 +324,8 @@ class _AlltoAll(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_:
raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!")
return x_dtype
def __call__(self, tensor):