forked from mindspore-Ecosystem/mindspore
fix bug in infer_dtype function of hcom operations
This commit is contained in:
parent
e213f2a435
commit
c046874b03
|
@ -45,7 +45,6 @@ class AllReduce(PrimitiveWithInfer):
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
The operation of AllReduce does not support "prod" currently.
|
The operation of AllReduce does not support "prod" currently.
|
||||||
The input of AllReduce does not support dtype "Bool".
|
|
||||||
Tensor must have same shape and format in all processes participating in the collective.
|
Tensor must have same shape and format in all processes participating in the collective.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -103,7 +102,7 @@ class AllReduce(PrimitiveWithInfer):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
def infer_dtype(self, x_dtype):
|
||||||
if x_dtype == mstype.bool_:
|
if x_dtype.element_type() == mstype.bool_:
|
||||||
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!")
|
raise TypeError("AllReduce does not support 'Bool' as the dtype of input!")
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
@ -161,7 +160,7 @@ class AllGather(PrimitiveWithInfer):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
def infer_dtype(self, x_dtype):
|
||||||
if x_dtype == mstype.bool_:
|
if x_dtype.element_type() == mstype.bool_:
|
||||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
@ -176,6 +175,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
||||||
Note:
|
Note:
|
||||||
The back propagation of the op is not surported yet. Stay tuned for more.
|
The back propagation of the op is not surported yet. Stay tuned for more.
|
||||||
Tensor must have the same shape and format in all processes participating in the collective.
|
Tensor must have the same shape and format in all processes participating in the collective.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
op (str): Specifies an operation used for element-wise reductions,
|
op (str): Specifies an operation used for element-wise reductions,
|
||||||
like sum, max, avg. Default: ReduceOp.SUM.
|
like sum, max, avg. Default: ReduceOp.SUM.
|
||||||
|
@ -218,7 +218,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
def infer_dtype(self, x_dtype):
|
||||||
if x_dtype == mstype.bool_:
|
if x_dtype.element_type() == mstype.bool_:
|
||||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
@ -275,8 +275,11 @@ class Broadcast(PrimitiveWithInfer):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
def infer_dtype(self, x_dtype):
|
||||||
if x_dtype == mstype.bool_:
|
if not isinstance(x_dtype, tuple):
|
||||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
raise TypeError(f"{self.name}'s input should be a tuple!")
|
||||||
|
for _ele in x_dtype:
|
||||||
|
if _ele.element_type() == mstype.bool_:
|
||||||
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
|
||||||
|
@ -318,7 +321,7 @@ class _AlltoAll(PrimitiveWithInfer):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
|
||||||
def infer_dtype(self, x_dtype):
|
def infer_dtype(self, x_dtype):
|
||||||
if x_dtype == mstype.bool_:
|
if x_dtype.element_type() == mstype.bool_:
|
||||||
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!")
|
||||||
return x_dtype
|
return x_dtype
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ class BroadCastNet(nn.Cell):
|
||||||
self.broadcast = Broadcast(0)
|
self.broadcast = Broadcast(0)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
x = self.broadcast((x))
|
x, = self.broadcast((x,))
|
||||||
x = self.dense(x)
|
x = self.dense(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
|
@ -52,7 +52,7 @@ class CommonNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(CommonNet, self).__init__()
|
super(CommonNet, self).__init__()
|
||||||
self.weight = Parameter(Tensor(np.ones([256, 64]), dtype=ms.float32), name="mul_weight")
|
self.weight = Parameter(Tensor(np.ones([256, 64]), dtype=ms.float32), name="mul_weight")
|
||||||
self.logicalnot = P.LogicalNot().set_strategy(((4,1),))
|
self.logicalnot = P.LogicalNot().set_strategy(((4,2),))
|
||||||
self.equal = P.Equal().set_strategy(((4,2),(4,2)))
|
self.equal = P.Equal().set_strategy(((4,2),(4,2)))
|
||||||
|
|
||||||
def construct(self, x, label):
|
def construct(self, x, label):
|
||||||
|
@ -78,4 +78,5 @@ def common_net():
|
||||||
|
|
||||||
|
|
||||||
def test_bool_grad():
|
def test_bool_grad():
|
||||||
common_net()
|
common_net()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue