diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index cad57ad2015..8cb59773c3a 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -177,12 +177,26 @@ class Primitive(Primitive_): for in_ele in in_strategy: if not isinstance(in_ele, tuple): raise TypeError(f'The element of strategy must be tuple type, but got:{type(in_ele)}') + for in_value in in_ele: + if not isinstance(in_value, int): + raise TypeError(f'The in_strategy: {in_strategy} of {self.name} is not valid,' + f' the value of strategy must be int type, but got:{type(in_value)}') + if out_strategy is not None: if not isinstance(out_strategy, tuple): raise TypeError(f'out strategy must be tuple type, but got:{type(out_strategy)}') for out_ele in out_strategy: if not isinstance(out_ele, tuple): raise TypeError(f'The element of strategy must be tuple type, but got:{type(out_ele)}') + for out_value in out_ele: + if not isinstance(out_value, int): + raise TypeError(f'The in_strategy: {out_strategy} of {self.name} is not valid,' + f' the value of strategy must be int type, but got:{type(out_value)}') + + if in_strategy is None and out_strategy is not None: + raise ValueError(f'The out_strategy of {self.name} is {out_strategy}, need to set in_strategy,' + f' but got none') + if not _is_in_auto_parallel_mode(): if in_strategy is not None: logger.warning(f"The in_strategy: {in_strategy} of {self.name} is not valid in {mode}. " @@ -190,6 +204,7 @@ class Primitive(Primitive_): if out_strategy is not None: logger.warning(f"The out_strategy: {out_strategy} of {self.name} is not valid in {mode}. " f"Please use semi auto or auto parallel mode.") + self.add_prim_attr("in_strategy", in_strategy) self.add_prim_attr("out_strategy", out_strategy) return self diff --git a/tests/ut/python/parallel/test_two_matmul.py b/tests/ut/python/parallel/test_two_matmul.py index baf54201cf9..a76ebfbc79c 100644 --- a/tests/ut/python/parallel/test_two_matmul.py +++ b/tests/ut/python/parallel/test_two_matmul.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import pytest import mindspore as ms import mindspore.nn as nn @@ -357,3 +358,81 @@ def test_matmul_output_strategy_all_reduce_transpose_repeat_calc(): y = Tensor(np.ones([64, 32]), dtype=ms.float32) b = Tensor(np.ones([128, 64]), dtype=ms.float32) compile_net(net, x, y, b) + + +def test_matmul_in_strategy_not_int(): + """ + Feature: the type of in_strategy's value is not int + Description: + Expectation: rasise TypeError + """ + class Net(nn.Cell): + def __init__(self, matmul_in_strategy, matmul_out_strategy, mul_strategy): + super().__init__() + self.matmul = P.MatMul(transpose_b=True).shard(matmul_in_strategy, matmul_out_strategy) + self.mul = P.Mul().shard(mul_strategy) + + def construct(self, x, y, b): + out = self.matmul(x, y) + out = self.mul(out, b) + return out + + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + matmul_in_strategy = ((2.0, 2), (2, 2)) + matmul_out_strategy = ((2, 2),) + mul_strategy = ((4, 2), (4, 2)) + + with pytest.raises(TypeError): + GradWrap(NetWithLoss(Net(matmul_in_strategy, matmul_out_strategy, mul_strategy))) + + +def test_matmul_out_strategy_not_int(): + """ + Feature: the type of out_strategy's value is not int + Description: + Expectation: rasise TypeError + """ + class Net(nn.Cell): + def __init__(self, matmul_in_strategy, matmul_out_strategy, mul_strategy): + super().__init__() + self.matmul = P.MatMul(transpose_b=True).shard(matmul_in_strategy, matmul_out_strategy) + self.mul = P.Mul().shard(mul_strategy) + + def construct(self, x, y, b): + out = self.matmul(x, y) + out = self.mul(out, b) + return out + + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + matmul_in_strategy = ((2, 2), (2, 2)) + matmul_out_strategy = ((2.0, 2),) + mul_strategy = ((4, 2), (4, 2)) + + with pytest.raises(TypeError): + GradWrap(NetWithLoss(Net(matmul_in_strategy, matmul_out_strategy, mul_strategy))) + + +def test_matmul_in_strategy_is_none_and_out_strategy_is_not_none(): + """ + Feature: the in_strategy is none and out_strategy is not none + Description: + Expectation: rasise ValueError + """ + class Net(nn.Cell): + def __init__(self, matmul_in_strategy, matmul_out_strategy, mul_strategy): + super().__init__() + self.matmul = P.MatMul(transpose_b=True).shard(matmul_in_strategy, matmul_out_strategy) + self.mul = P.Mul().shard(mul_strategy) + + def construct(self, x, y, b): + out = self.matmul(x, y) + out = self.mul(out, b) + return out + + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) + matmul_in_strategy = None + matmul_out_strategy = ((2, 2),) + mul_strategy = ((4, 2), (4, 2)) + + with pytest.raises(ValueError): + GradWrap(NetWithLoss(Net(matmul_in_strategy, matmul_out_strategy, mul_strategy)))