fix bug in infer_dtype function of hcom operations

This commit is contained in:
zhouyuanshen 2020-04-27 10:08:09 +08:00
parent e213f2a435
commit c046874b03
3 changed files with 14 additions and 10 deletions

View File

@ -45,7 +45,6 @@ class AllReduce(PrimitiveWithInfer):
Note: Note:
The operation of AllReduce does not support "prod" currently. The operation of AllReduce does not support "prod" currently.
The input of AllReduce does not support dtype "Bool".
Tensor must have same shape and format in all processes participating in the collective. Tensor must have same shape and format in all processes participating in the collective.
Args: Args:
@ -103,7 +102,7 @@ class AllReduce(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_: if x_dtype.element_type() == mstype.bool_:
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!") raise TypeError("AllReduce does not support 'Bool' as the dtype of input!")
return x_dtype return x_dtype
@ -161,7 +160,7 @@ class AllGather(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_: if x_dtype.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype return x_dtype
@ -176,6 +175,7 @@ class ReduceScatter(PrimitiveWithInfer):
Note: Note:
The back propagation of the op is not surported yet. Stay tuned for more. The back propagation of the op is not surported yet. Stay tuned for more.
Tensor must have the same shape and format in all processes participating in the collective. Tensor must have the same shape and format in all processes participating in the collective.
Args: Args:
op (str): Specifies an operation used for element-wise reductions, op (str): Specifies an operation used for element-wise reductions,
like sum, max, avg. Default: ReduceOp.SUM. like sum, max, avg. Default: ReduceOp.SUM.
@ -218,7 +218,7 @@ class ReduceScatter(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_: if x_dtype.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype return x_dtype
@ -275,7 +275,10 @@ class Broadcast(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_: if not isinstance(x_dtype, tuple):
raise TypeError(f"{self.name}'s input should be a tuple!")
for _ele in x_dtype:
if _ele.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype return x_dtype
@ -318,7 +321,7 @@ class _AlltoAll(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
if x_dtype == mstype.bool_: if x_dtype.element_type() == mstype.bool_:
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
return x_dtype return x_dtype

View File

@ -55,7 +55,7 @@ class BroadCastNet(nn.Cell):
self.broadcast = Broadcast(0) self.broadcast = Broadcast(0)
def construct(self, x): def construct(self, x):
x = self.broadcast((x)) x, = self.broadcast((x,))
x = self.dense(x) x = self.dense(x)
return x return x

View File

@ -52,7 +52,7 @@ class CommonNet(nn.Cell):
def __init__(self): def __init__(self):
super(CommonNet, self).__init__() super(CommonNet, self).__init__()
self.weight = Parameter(Tensor(np.ones([256, 64]), dtype=ms.float32), name="mul_weight") self.weight = Parameter(Tensor(np.ones([256, 64]), dtype=ms.float32), name="mul_weight")
self.logicalnot = P.LogicalNot().set_strategy(((4,1),)) self.logicalnot = P.LogicalNot().set_strategy(((4,2),))
self.equal = P.Equal().set_strategy(((4,2),(4,2))) self.equal = P.Equal().set_strategy(((4,2),(4,2)))
def construct(self, x, label): def construct(self, x, label):
@ -79,3 +79,4 @@ def common_net():
def test_bool_grad(): def test_bool_grad():
common_net() common_net()