diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc
index 931201319b5..6d35bf7274c 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc
@@ -302,6 +302,21 @@ Status LerpInfo::InferTensorMap() {
   return SUCCESS;
 }
 
+Status LerpInfo::InferMirrorOps() {
+  if (OperatorInfo::InferMirrorOps() != SUCCESS) {
+    return FAILED;
+  }
+  // No need to insert mirror ops
+  if (mirror_ops_.empty()) {
+    return SUCCESS;
+  }
+  if (mirror_ops_.size() == kSizeTwo) {
+    // Push empty mirror op for value
+    (void)mirror_ops_.emplace_back(OperatorVector());
+  }
+  return SUCCESS;
+}
+
 std::vector<StrategyPtr> LerpInfo::GenerateOpStrategies(int64_t stage_id) {
   if (inputs_size_ == 2) {
     return ArithmeticBase::GenerateOpStrategies(stage_id);
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h
index 9aafef15ce6..0dae6a2a68c 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h
@@ -322,6 +322,7 @@ class LerpInfo : public ArithmeticBase {
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
+  Status InferMirrorOps() override;
 
  private:
   size_t inputs_size_ = 0;
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index b2feba4bb9a..1e3858fc2aa 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -973,8 +973,8 @@ static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &no
   }
 
   if (mirror_ops.size() != node_size - 1) {
-    MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
-                      << (node_size - 1);
+    MS_LOG(EXCEPTION) << "For " << node->fullname_with_scope() << ", the Mirrorops's size is wrong! mirror_ops size is "
+                      << mirror_ops.size() << ", node_size is " << (node_size - 1);
   }
   return true;
 }
diff --git a/tests/ut/python/parallel/test_lerp.py b/tests/ut/python/parallel/test_lerp.py
index bd7ae785d17..5685085c7cd 100644
--- a/tests/ut/python/parallel/test_lerp.py
+++ b/tests/ut/python/parallel/test_lerp.py
@@ -3,14 +3,16 @@ import pytest
 from mindspore import Tensor, context
 from mindspore.nn import Cell
+from mindspore.ops import composite as C
 import mindspore.ops as ops
-from parallel.utils.utils import compile_net
+from parallel.utils.utils import compile_net, ParallelValidator
 
 
 def setup_function():
     context.set_auto_parallel_context(dataset_strategy="full_batch")
 
+
 input_start_ = Tensor(np.random.normal(size=[8, 8, 8]).astype(np.float32))
 input_end_ = Tensor(np.random.normal(size=[8]).astype(np.float32))
 input_weight_tensor_ = Tensor(np.random.normal(size=[8, 8]).astype(np.float32))
@@ -27,6 +29,27 @@ class Net(Cell):
         return output
 
 
+class NetWithWeightFloat(Cell):
+    def __init__(self, weight, strategy=None):
+        super(NetWithWeightFloat, self).__init__()
+        self.weight = weight
+        self.lerp = ops.Lerp().shard(strategy)
+
+    def construct(self, *inputs):
+        output = self.lerp(*inputs, self.weight)
+        return output
+
+
+class GradWrap(Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+        self.grad_op = C.GradOperation()
+
+    def construct(self, *inputs):
+        return self.grad_op(self.network)(*inputs)
+
+
 def test_lerp_auto_parallel_with_weight_tensor():
     """
     Feature: test Lerp auto parallel
@@ -45,8 +68,8 @@ def test_lerp_auto_parallel_with_weight_float():
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
-    net = Net()
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_model_parallel_with_weight_tensor():
@@ -69,8 +92,8 @@ def test_lerp_model_parallel_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((2, 2, 2), (2,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_, strategy)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_model_parallel_repeated_cal_with_weight_tensor():
@@ -81,8 +104,10 @@ def test_lerp_model_parallel_repeated_cal_with_weight_tensor():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((1, 2, 2), (2,), (2, 2))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_tensor_)
+    net = GradWrap(Net(strategy))
+    phase = compile_net(net, input_start_, input_end_, input_weight_tensor_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs("Lerp-0", ["_VirtualDiv-0", "_VirtualDiv-1", "_VirtualDiv-2"])
 
 
 def test_lerp_model_parallel_repeated_cal_with_weight_float():
@@ -93,8 +118,10 @@ def test_lerp_model_parallel_repeated_cal_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((1, 2, 2), (2,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = GradWrap(NetWithWeightFloat(input_weight_float_, strategy))
+    phase = compile_net(net, input_start_, input_end_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs("Lerp-0", ["_VirtualDiv-0", "_VirtualDiv-1", input_weight_float_])
 
 
 def test_lerp_data_parallel_with_weight_tensor():
@@ -104,8 +131,7 @@ def test_lerp_data_parallel_with_weight_tensor():
     Feature: test Lerp data parallel strategy
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    strategy = ((8, 1, 1), (1,), (1, 1))
-    net = Net(strategy)
+    net = Net()
     compile_net(net, input_start_, input_end_, input_weight_tensor_)
 
 
@@ -116,9 +142,8 @@ def test_lerp_data_parallel_with_weight_float():
     Feature: test Lerp data parallel strategy
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    strategy = ((8, 1, 1), (1,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_strategy_error_with_weight_tensor():
@@ -142,6 +167,6 @@ def test_lerp_strategy_error_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((4, 1, 2), (1,))
-    net = Net(strategy)
+    net = NetWithWeightFloat(input_weight_float_, strategy)
     with pytest.raises(RuntimeError):
-        compile_net(net, input_start_, input_end_, input_weight_float_)
+        compile_net(net, input_start_, input_end_)