From c046874b0389839edf911af618f3d7dabbf30d5b Mon Sep 17 00:00:00 2001 From: zhouyuanshen Date: Mon, 27 Apr 2020 10:08:09 +0800 Subject: [PATCH] fix bug in infer_dtype function of hcom operations --- mindspore/ops/operations/comm_ops.py | 17 ++++++++++------- tests/ut/python/communication/test_comm.py | 2 +- tests/ut/python/parallel/test_bool_grad.py | 5 +++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 969091de971..5fb5f3ed952 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -45,7 +45,6 @@ class AllReduce(PrimitiveWithInfer): Note: The operation of AllReduce does not support "prod" currently. - The input of AllReduce does not support dtype "Bool". Tensor must have same shape and format in all processes participating in the collective. Args: @@ -103,7 +102,7 @@ class AllReduce(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: + if x_dtype.element_type() == mstype.bool_: raise TypeError("AllReduce does not support 'Bool' as the dtype of input!") return x_dtype @@ -161,7 +160,7 @@ class AllGather(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: + if x_dtype.element_type() == mstype.bool_: raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype @@ -176,6 +175,7 @@ class ReduceScatter(PrimitiveWithInfer): Note: The back propagation of the op is not surported yet. Stay tuned for more. Tensor must have the same shape and format in all processes participating in the collective. + Args: op (str): Specifies an operation used for element-wise reductions, like sum, max, avg. Default: ReduceOp.SUM. @@ -218,7 +218,7 @@ class ReduceScatter(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: + if x_dtype.element_type() == mstype.bool_: raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype @@ -275,8 +275,11 @@ class Broadcast(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: - raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") + if not isinstance(x_dtype, tuple): + raise TypeError(f"{self.name}'s input should be a tuple!") + for _ele in x_dtype: + if _ele.element_type() == mstype.bool_: + raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype @@ -318,7 +321,7 @@ class _AlltoAll(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - if x_dtype == mstype.bool_: + if x_dtype.element_type() == mstype.bool_: raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype diff --git a/tests/ut/python/communication/test_comm.py b/tests/ut/python/communication/test_comm.py index 38fd7199fd4..885c8fa9e34 100644 --- a/tests/ut/python/communication/test_comm.py +++ b/tests/ut/python/communication/test_comm.py @@ -55,7 +55,7 @@ class BroadCastNet(nn.Cell): self.broadcast = Broadcast(0) def construct(self, x): - x = self.broadcast((x)) + x, = self.broadcast((x,)) x = self.dense(x) return x diff --git a/tests/ut/python/parallel/test_bool_grad.py b/tests/ut/python/parallel/test_bool_grad.py index f3cdfc80304..491707103b8 100644 --- a/tests/ut/python/parallel/test_bool_grad.py +++ b/tests/ut/python/parallel/test_bool_grad.py @@ -52,7 +52,7 @@ class CommonNet(nn.Cell): def __init__(self): super(CommonNet, self).__init__() self.weight = Parameter(Tensor(np.ones([256, 64]), dtype=ms.float32), name="mul_weight") - self.logicalnot = P.LogicalNot().set_strategy(((4,1),)) + self.logicalnot = P.LogicalNot().set_strategy(((4,2),)) self.equal = P.Equal().set_strategy(((4,2),(4,2))) def construct(self, x, label): @@ -78,4 +78,5 @@ def common_net(): def test_bool_grad(): - common_net() \ No newline at end of file + common_net() +