check args for shard
This commit is contained in:
parent
becf381908
commit
7454b8f8f2
|
@ -177,12 +177,26 @@ class Primitive(Primitive_):
|
||||||
for in_ele in in_strategy:
|
for in_ele in in_strategy:
|
||||||
if not isinstance(in_ele, tuple):
|
if not isinstance(in_ele, tuple):
|
||||||
raise TypeError(f'The element of strategy must be tuple type, but got:{type(in_ele)}')
|
raise TypeError(f'The element of strategy must be tuple type, but got:{type(in_ele)}')
|
||||||
|
for in_value in in_ele:
|
||||||
|
if not isinstance(in_value, int):
|
||||||
|
raise TypeError(f'The in_strategy: {in_strategy} of {self.name} is not valid,'
|
||||||
|
f' the value of strategy must be int type, but got:{type(in_value)}')
|
||||||
|
|
||||||
if out_strategy is not None:
|
if out_strategy is not None:
|
||||||
if not isinstance(out_strategy, tuple):
|
if not isinstance(out_strategy, tuple):
|
||||||
raise TypeError(f'out strategy must be tuple type, but got:{type(out_strategy)}')
|
raise TypeError(f'out strategy must be tuple type, but got:{type(out_strategy)}')
|
||||||
for out_ele in out_strategy:
|
for out_ele in out_strategy:
|
||||||
if not isinstance(out_ele, tuple):
|
if not isinstance(out_ele, tuple):
|
||||||
raise TypeError(f'The element of strategy must be tuple type, but got:{type(out_ele)}')
|
raise TypeError(f'The element of strategy must be tuple type, but got:{type(out_ele)}')
|
||||||
|
for out_value in out_ele:
|
||||||
|
if not isinstance(out_value, int):
|
||||||
|
raise TypeError(f'The in_strategy: {out_strategy} of {self.name} is not valid,'
|
||||||
|
f' the value of strategy must be int type, but got:{type(out_value)}')
|
||||||
|
|
||||||
|
if in_strategy is None and out_strategy is not None:
|
||||||
|
raise ValueError(f'The out_strategy of {self.name} is {out_strategy}, need to set in_strategy,'
|
||||||
|
f' but got none')
|
||||||
|
|
||||||
if not _is_in_auto_parallel_mode():
|
if not _is_in_auto_parallel_mode():
|
||||||
if in_strategy is not None:
|
if in_strategy is not None:
|
||||||
logger.warning(f"The in_strategy: {in_strategy} of {self.name} is not valid in {mode}. "
|
logger.warning(f"The in_strategy: {in_strategy} of {self.name} is not valid in {mode}. "
|
||||||
|
@ -190,6 +204,7 @@ class Primitive(Primitive_):
|
||||||
if out_strategy is not None:
|
if out_strategy is not None:
|
||||||
logger.warning(f"The out_strategy: {out_strategy} of {self.name} is not valid in {mode}. "
|
logger.warning(f"The out_strategy: {out_strategy} of {self.name} is not valid in {mode}. "
|
||||||
f"Please use semi auto or auto parallel mode.")
|
f"Please use semi auto or auto parallel mode.")
|
||||||
|
|
||||||
self.add_prim_attr("in_strategy", in_strategy)
|
self.add_prim_attr("in_strategy", in_strategy)
|
||||||
self.add_prim_attr("out_strategy", out_strategy)
|
self.add_prim_attr("out_strategy", out_strategy)
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
|
@ -357,3 +358,81 @@ def test_matmul_output_strategy_all_reduce_transpose_repeat_calc():
|
||||||
y = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
y = Tensor(np.ones([64, 32]), dtype=ms.float32)
|
||||||
b = Tensor(np.ones([128, 64]), dtype=ms.float32)
|
b = Tensor(np.ones([128, 64]), dtype=ms.float32)
|
||||||
compile_net(net, x, y, b)
|
compile_net(net, x, y, b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_matmul_in_strategy_not_int():
    """
    Feature: the type of in_strategy's value is not int
    Description: shard MatMul with an in_strategy containing a float element (2.0)
    Expectation: raise TypeError
    """
    class Net(nn.Cell):
        def __init__(self, matmul_in_strategy, matmul_out_strategy, mul_strategy):
            super().__init__()
            self.matmul = P.MatMul(transpose_b=True).shard(matmul_in_strategy, matmul_out_strategy)
            self.mul = P.Mul().shard(mul_strategy)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.mul(out, b)
            return out

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    # 2.0 is deliberately a float: shard() must reject non-int strategy values.
    matmul_in_strategy = ((2.0, 2), (2, 2))
    matmul_out_strategy = ((2, 2),)
    mul_strategy = ((4, 2), (4, 2))

    with pytest.raises(TypeError):
        GradWrap(NetWithLoss(Net(matmul_in_strategy, matmul_out_strategy, mul_strategy)))
||||||
|
|
||||||
|
|
||||||
|
def test_matmul_out_strategy_not_int():
    """
    Feature: the type of out_strategy's value is not int
    Description: shard MatMul with an out_strategy containing a float element (2.0)
    Expectation: raise TypeError
    """
    class Net(nn.Cell):
        def __init__(self, matmul_in_strategy, matmul_out_strategy, mul_strategy):
            super().__init__()
            self.matmul = P.MatMul(transpose_b=True).shard(matmul_in_strategy, matmul_out_strategy)
            self.mul = P.Mul().shard(mul_strategy)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.mul(out, b)
            return out

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    matmul_in_strategy = ((2, 2), (2, 2))
    # 2.0 is deliberately a float: shard() must reject non-int strategy values.
    matmul_out_strategy = ((2.0, 2),)
    mul_strategy = ((4, 2), (4, 2))

    with pytest.raises(TypeError):
        GradWrap(NetWithLoss(Net(matmul_in_strategy, matmul_out_strategy, mul_strategy)))
|
|
||||||
|
|
||||||
|
def test_matmul_in_strategy_is_none_and_out_strategy_is_not_none():
    """
    Feature: the in_strategy is none and out_strategy is not none
    Description: shard MatMul with in_strategy=None while out_strategy is set
    Expectation: raise ValueError
    """
    class Net(nn.Cell):
        def __init__(self, matmul_in_strategy, matmul_out_strategy, mul_strategy):
            super().__init__()
            self.matmul = P.MatMul(transpose_b=True).shard(matmul_in_strategy, matmul_out_strategy)
            self.mul = P.Mul().shard(mul_strategy)

        def construct(self, x, y, b):
            out = self.matmul(x, y)
            out = self.mul(out, b)
            return out

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    # in_strategy=None with a non-None out_strategy is an invalid combination.
    matmul_in_strategy = None
    matmul_out_strategy = ((2, 2),)
    mul_strategy = ((4, 2), (4, 2))

    with pytest.raises(ValueError):
        GradWrap(NetWithLoss(Net(matmul_in_strategy, matmul_out_strategy, mul_strategy)))
||||||
|
|
Loading…
Reference in New Issue