check args for shard

This commit is contained in:
yangzhenzhang 2021-12-02 11:26:27 +08:00
parent becf381908
commit 7454b8f8f2
2 changed files with 94 additions and 0 deletions

View File

@ -177,12 +177,26 @@ class Primitive(Primitive_):
for in_ele in in_strategy:
if not isinstance(in_ele, tuple):
raise TypeError(f'The element of strategy must be tuple type, but got:{type(in_ele)}')
for in_value in in_ele:
if not isinstance(in_value, int):
raise TypeError(f'The in_strategy: {in_strategy} of {self.name} is not valid,'
f' the value of strategy must be int type, but got:{type(in_value)}')
if out_strategy is not None:
if not isinstance(out_strategy, tuple):
raise TypeError(f'out strategy must be tuple type, but got:{type(out_strategy)}')
for out_ele in out_strategy:
if not isinstance(out_ele, tuple):
raise TypeError(f'The element of strategy must be tuple type, but got:{type(out_ele)}')
for out_value in out_ele:
if not isinstance(out_value, int):
raise TypeError(f'The in_strategy: {out_strategy} of {self.name} is not valid,'
f' the value of strategy must be int type, but got:{type(out_value)}')
if in_strategy is None and out_strategy is not None:
raise ValueError(f'The out_strategy of {self.name} is {out_strategy}, need to set in_strategy,'
f' but got none')
if not _is_in_auto_parallel_mode():
if in_strategy is not None:
logger.warning(f"The in_strategy: {in_strategy} of {self.name} is not valid in {mode}. "
@ -190,6 +204,7 @@ class Primitive(Primitive_):
if out_strategy is not None:
logger.warning(f"The out_strategy: {out_strategy} of {self.name} is not valid in {mode}. "
f"Please use semi auto or auto parallel mode.")
self.add_prim_attr("in_strategy", in_strategy)
self.add_prim_attr("out_strategy", out_strategy)
return self

View File

@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
@ -357,3 +358,81 @@ def test_matmul_output_strategy_all_reduce_transpose_repeat_calc():
y = Tensor(np.ones([64, 32]), dtype=ms.float32)
b = Tensor(np.ones([128, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_in_strategy_not_int():
    """
    Feature: shard() validation when an in_strategy value is not an int
    Description: in_strategy for MatMul contains a float (2.0)
    Expectation: raise TypeError
    """
    class Net(nn.Cell):
        def __init__(self, mm_in_stra, mm_out_stra, mul_stra):
            super().__init__()
            # shard() must reject the float inside mm_in_stra
            self.matmul = P.MatMul(transpose_b=True).shard(mm_in_stra, mm_out_stra)
            self.mul = P.Mul().shard(mul_stra)

        def construct(self, x, y, b):
            return self.mul(self.matmul(x, y), b)

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    mm_in_stra = ((2.0, 2), (2, 2))
    mm_out_stra = ((2, 2),)
    mul_stra = ((4, 2), (4, 2))
    with pytest.raises(TypeError):
        GradWrap(NetWithLoss(Net(mm_in_stra, mm_out_stra, mul_stra)))
def test_matmul_out_strategy_not_int():
    """
    Feature: shard() validation when an out_strategy value is not an int
    Description: out_strategy for MatMul contains a float (2.0)
    Expectation: raise TypeError
    """
    class Net(nn.Cell):
        def __init__(self, mm_in_stra, mm_out_stra, mul_stra):
            super().__init__()
            # shard() must reject the float inside mm_out_stra
            self.matmul = P.MatMul(transpose_b=True).shard(mm_in_stra, mm_out_stra)
            self.mul = P.Mul().shard(mul_stra)

        def construct(self, x, y, b):
            return self.mul(self.matmul(x, y), b)

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    mm_in_stra = ((2, 2), (2, 2))
    mm_out_stra = ((2.0, 2),)
    mul_stra = ((4, 2), (4, 2))
    with pytest.raises(TypeError):
        GradWrap(NetWithLoss(Net(mm_in_stra, mm_out_stra, mul_stra)))
def test_matmul_in_strategy_is_none_and_out_strategy_is_not_none():
    """
    Feature: shard() validation when in_strategy is None but out_strategy is set
    Description: MatMul is sharded with in_strategy=None and a non-None out_strategy
    Expectation: raise ValueError
    """
    class Net(nn.Cell):
        def __init__(self, mm_in_stra, mm_out_stra, mul_stra):
            super().__init__()
            # out_strategy without in_strategy is an invalid combination
            self.matmul = P.MatMul(transpose_b=True).shard(mm_in_stra, mm_out_stra)
            self.mul = P.Mul().shard(mul_stra)

        def construct(self, x, y, b):
            return self.mul(self.matmul(x, y), b)

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    mm_in_stra = None
    mm_out_stra = ((2, 2),)
    mul_stra = ((4, 2), (4, 2))
    with pytest.raises(ValueError):
        GradWrap(NetWithLoss(Net(mm_in_stra, mm_out_stra, mul_stra)))