forked from mindspore-Ecosystem/mindspore
回退 'Pull Request !231 : add bool type check in communication operator '
This commit is contained in:
parent
5141054ecd
commit
a2850cae32
|
@ -162,8 +162,6 @@ class AllGather(PrimitiveWithInfer):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
def infer_dtype(self, x_dtype):
|
||||||
if x_dtype == mstype.bool_:
|
|
||||||
raise TypeError("AllGather does not support 'Bool' as the dtype of input!")
|
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
def __call__(self, tensor):
|
def __call__(self, tensor):
|
||||||
|
@ -221,8 +219,6 @@ class ReduceScatter(PrimitiveWithInfer):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
def infer_dtype(self, x_dtype):
|
||||||
if x_dtype == mstype.bool_:
|
|
||||||
raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!")
|
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
def __call__(self, tensor):
|
def __call__(self, tensor):
|
||||||
|
@ -280,8 +276,6 @@ class Broadcast(PrimitiveWithInfer):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
def infer_dtype(self, x_dtype):
|
||||||
if x_dtype == mstype.bool_:
|
|
||||||
raise TypeError("Broadcast does not support 'Bool' as the dtype of input!")
|
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
|
@ -324,8 +318,6 @@ class _AlltoAll(PrimitiveWithInfer):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
def infer_dtype(self, x_dtype):
|
||||||
if x_dtype == mstype.bool_:
|
|
||||||
raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!")
|
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
def __call__(self, tensor):
|
def __call__(self, tensor):
|
||||||
|
|
Loading…
Reference in New Issue