forked from mindspore-Ecosystem/mindspore
!231 add bool type check in communication operator
Merge pull request !231 from chentingting/add_bool_type_check_in_comm_op
This commit is contained in:
commit
54481c30c8
|
@ -162,6 +162,8 @@ class AllGather(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype == mstype.bool_:
|
||||
raise TypeError("AllGather does not support 'Bool' as the dtype of input!")
|
||||
return x_dtype
|
||||
|
||||
def __call__(self, tensor):
|
||||
|
@ -219,6 +221,8 @@ class ReduceScatter(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype == mstype.bool_:
|
||||
raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!")
|
||||
return x_dtype
|
||||
|
||||
def __call__(self, tensor):
|
||||
|
@ -276,6 +280,8 @@ class Broadcast(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype == mstype.bool_:
|
||||
raise TypeError("Broadcast does not support 'Bool' as the dtype of input!")
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -318,6 +324,8 @@ class _AlltoAll(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype == mstype.bool_:
|
||||
raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!")
|
||||
return x_dtype
|
||||
|
||||
def __call__(self, tensor):
|
||||
|
|
Loading…
Reference in New Issue