forked from mindspore-Ecosystem/mindspore
fix bug in infer_dtype function of hcom operations
This commit is contained in:
parent
e213f2a435
commit
c046874b03
|
@ -45,7 +45,6 @@ class AllReduce(PrimitiveWithInfer):
|
|||
|
||||
Note:
|
||||
The operation of AllReduce does not support "prod" currently.
|
||||
The input of AllReduce does not support dtype "Bool".
|
||||
Tensor must have same shape and format in all processes participating in the collective.
|
||||
|
||||
Args:
|
||||
|
@ -103,7 +102,7 @@ class AllReduce(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype == mstype.bool_:
|
||||
if x_dtype.element_type() == mstype.bool_:
|
||||
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!")
|
||||
return x_dtype
|
||||
|
||||
|
@ -161,7 +160,7 @@ class AllGather(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype == mstype.bool_:
|
||||
if x_dtype.element_type() == mstype.bool_:
|
||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||
return x_dtype
|
||||
|
||||
|
@ -176,6 +175,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
|||
Note:
|
||||
The back propagation of the op is not surported yet. Stay tuned for more.
|
||||
Tensor must have the same shape and format in all processes participating in the collective.
|
||||
|
||||
Args:
|
||||
op (str): Specifies an operation used for element-wise reductions,
|
||||
like sum, max, avg. Default: ReduceOp.SUM.
|
||||
|
@ -218,7 +218,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype == mstype.bool_:
|
||||
if x_dtype.element_type() == mstype.bool_:
|
||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||
return x_dtype
|
||||
|
||||
|
@ -275,8 +275,11 @@ class Broadcast(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype == mstype.bool_:
|
||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||
if not isinstance(x_dtype, tuple):
|
||||
raise TypeError(f"{self.name}'s input should be a tuple!")
|
||||
for _ele in x_dtype:
|
||||
if _ele.element_type() == mstype.bool_:
|
||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||
return x_dtype
|
||||
|
||||
|
||||
|
@ -318,7 +321,7 @@ class _AlltoAll(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
if x_dtype == mstype.bool_:
|
||||
if x_dtype.element_type() == mstype.bool_:
|
||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||
return x_dtype
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ class BroadCastNet(nn.Cell):
|
|||
self.broadcast = Broadcast(0)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.broadcast((x))
|
||||
x, = self.broadcast((x,))
|
||||
x = self.dense(x)
|
||||
return x
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ class CommonNet(nn.Cell):
|
|||
def __init__(self):
|
||||
super(CommonNet, self).__init__()
|
||||
self.weight = Parameter(Tensor(np.ones([256, 64]), dtype=ms.float32), name="mul_weight")
|
||||
self.logicalnot = P.LogicalNot().set_strategy(((4,1),))
|
||||
self.logicalnot = P.LogicalNot().set_strategy(((4,2),))
|
||||
self.equal = P.Equal().set_strategy(((4,2),(4,2)))
|
||||
|
||||
def construct(self, x, label):
|
||||
|
@ -78,4 +78,5 @@ def common_net():
|
|||
|
||||
|
||||
def test_bool_grad():
|
||||
common_net()
|
||||
common_net()
|
||||
|
||||
|
|
Loading…
Reference in New Issue