forked from mindspore-Ecosystem/mindspore
!1080 Fix dtype check in infer_dtype() function of comm_ops.py
Merge pull request !1080 from zhouyuanshen/master
This commit is contained in:
commit
331ca249ef
|
@ -39,6 +39,8 @@ class ReduceOp:
|
|||
PROD = "prod"
|
||||
|
||||
|
||||
target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)
|
||||
|
||||
class AllReduce(PrimitiveWithInfer):
|
||||
"""
|
||||
Reduces the tensor data across all devices in such a way that all devices will get the same final result.
|
||||
|
@ -102,8 +104,7 @@ class AllReduce(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype.element_type() == mstype.bool_:
|
||||
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!")
|
||||
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -161,8 +162,7 @@ class AllGather(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype.element_type() == mstype.bool_:
|
||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
|
||||
return x_dtype
|
||||
|
||||
def __call__(self, tensor):
|
||||
|
@ -219,8 +219,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype.element_type() == mstype.bool_:
|
||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
|
||||
return x_dtype
|
||||
|
||||
def __call__(self, tensor):
|
||||
|
@ -279,8 +278,7 @@ class Broadcast(PrimitiveWithInfer):
|
|||
if not isinstance(x_dtype, tuple):
|
||||
raise TypeError(f"{self.name}'s input should be a tuple!")
|
||||
for _ele in x_dtype:
|
||||
if _ele.element_type() == mstype.bool_:
|
||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||
validator.check_tensor_type_same({'x': _ele}, target_dtypes, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -322,8 +320,7 @@ class _AlltoAll(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype.element_type() == mstype.bool_:
|
||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
|
||||
return x_dtype
|
||||
|
||||
def __call__(self, tensor):
|
||||
|
|
Loading…
Reference in New Issue