!48180 Fix bug in LerpInfo mirror ops when 'value' is of float type

Merge pull request !48180 from liuluobin/master_fix_lerp
i-robot 2023-02-09 08:39:01 +00:00 committed by Gitee
commit 2424649dd0
4 changed files with 59 additions and 18 deletions
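
For context, here is a minimal sketch of the case this commit fixes, condensed from the new test cases further below (the class name, the 0.5 value and the local tensor names are illustrative, not part of the commit): sharding Lerp under semi_auto_parallel while passing 'value' as a plain Python float. Before this change, LerpInfo produced mirror ops only for the two tensor inputs, which is the size mismatch addressed by the new InferMirrorOps override and reported by CheckInsertMirrorOps below.

import numpy as np
import mindspore.ops as ops
from mindspore import Tensor, context
from mindspore.nn import Cell


class LerpFloatValueNet(Cell):
    """Lerp whose 'value' argument is a Python float instead of a Tensor."""
    def __init__(self, value, strategy=None):
        super(LerpFloatValueNet, self).__init__()
        self.value = value
        self.lerp = ops.Lerp().shard(strategy)

    def construct(self, start, end):
        return self.lerp(start, end, self.value)


context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
# The shard strategy covers only the two tensor inputs; the scalar 'value' has no strategy entry.
net = LerpFloatValueNet(0.5, strategy=((2, 2, 2), (2,)))
start = Tensor(np.random.normal(size=[8, 8, 8]).astype(np.float32))
end = Tensor(np.random.normal(size=[8]).astype(np.float32))
# Compiling net(start, end) in graph mode exercises the mirror-ops path patched below.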


@@ -302,6 +302,21 @@ Status LerpInfo::InferTensorMap() {
   return SUCCESS;
 }
+
+Status LerpInfo::InferMirrorOps() {
+  if (OperatorInfo::InferMirrorOps() != SUCCESS) {
+    return FAILED;
+  }
+  // No need to insert mirror ops
+  if (mirror_ops_.empty()) {
+    return SUCCESS;
+  }
+  if (mirror_ops_.size() == kSizeTwo) {
+    // Push empty mirror op for value
+    (void)mirror_ops_.emplace_back(OperatorVector());
+  }
+  return SUCCESS;
+}
 
 std::vector<StrategyPtr> LerpInfo::GenerateOpStrategies(int64_t stage_id) {
   if (inputs_size_ == 2) {
     return ArithmeticBase::GenerateOpStrategies(stage_id);


@@ -322,6 +322,7 @@ class LerpInfo : public ArithmeticBase {
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
+  Status InferMirrorOps() override;
 
  private:
   size_t inputs_size_ = 0;


@@ -973,8 +973,8 @@ static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &no
   }
 
   if (mirror_ops.size() != node_size - 1) {
-    MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
-                      << (node_size - 1);
+    MS_LOG(EXCEPTION) << "For " << node->fullname_with_scope() << ", the Mirrorops's size is wrong! mirror_ops size is "
+                      << mirror_ops.size() << ", node_size is " << (node_size - 1);
   }
   return true;
 }
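
Taken together, the first and third changes enforce one simple invariant. The following Python fragment is purely illustrative (the names and values mimic the shapes of the data involved, not the real C++ types): CheckInsertMirrorOps requires one operator list per operator input, so when 'value' is a scalar float the third slot must still exist, even if empty, which is exactly what the new kSizeTwo padding in LerpInfo::InferMirrorOps provides.

# Illustrative sketch only; names and values here are invented for explanation.
node_inputs = ["start", "end", "value"]          # 'value' arrives as a Python float, not a Tensor
mirror_ops = [["mirror_start"], ["mirror_end"]]  # mirrors were generated for the tensor inputs only

# What LerpInfo::InferMirrorOps now does: pad with an empty slot for 'value'.
if len(mirror_ops) == 2:      # corresponds to the kSizeTwo branch
    mirror_ops.append([])     # empty OperatorVector for the scalar input

# The condition CheckInsertMirrorOps verifies; it previously raised
# "Mirrorops's size is wrong!" because the third slot was missing.
assert len(mirror_ops) == len(node_inputs)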


@@ -3,14 +3,16 @@ import pytest
 from mindspore import Tensor, context
 from mindspore.nn import Cell
+from mindspore.ops import composite as C
 import mindspore.ops as ops
-from parallel.utils.utils import compile_net
+from parallel.utils.utils import compile_net, ParallelValidator
 
 
 def setup_function():
     context.set_auto_parallel_context(dataset_strategy="full_batch")
 
 
 input_start_ = Tensor(np.random.normal(size=[8, 8, 8]).astype(np.float32))
 input_end_ = Tensor(np.random.normal(size=[8]).astype(np.float32))
 input_weight_tensor_ = Tensor(np.random.normal(size=[8, 8]).astype(np.float32))
@@ -27,6 +29,27 @@ class Net(Cell):
         return output
 
 
+class NetWithWeightFloat(Cell):
+    def __init__(self, weight, strategy=None):
+        super(NetWithWeightFloat, self).__init__()
+        self.weight = weight
+        self.lerp = ops.Lerp().shard(strategy)
+
+    def construct(self, *inputs):
+        output = self.lerp(*inputs, self.weight)
+        return output
+
+
+class GradWrap(Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+        self.grad_op = C.GradOperation()
+
+    def construct(self, *inputs):
+        return self.grad_op(self.network)(*inputs)
+
+
 def test_lerp_auto_parallel_with_weight_tensor():
     """
     Feature: test Lerp auto parallel
@@ -45,8 +68,8 @@ def test_lerp_auto_parallel_with_weight_float():
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
-    net = Net()
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_model_parallel_with_weight_tensor():
@@ -69,8 +92,8 @@ def test_lerp_model_parallel_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((2, 2, 2), (2,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_, strategy)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_model_parallel_repeated_cal_with_weight_tensor():
@@ -81,8 +104,10 @@ def test_lerp_model_parallel_repeated_cal_with_weight_tensor():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((1, 2, 2), (2,), (2, 2))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_tensor_)
+    net = GradWrap(Net(strategy))
+    phase = compile_net(net, input_start_, input_end_, input_weight_tensor_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs("Lerp-0", ["_VirtualDiv-0", "_VirtualDiv-1", "_VirtualDiv-2"])
 
 
 def test_lerp_model_parallel_repeated_cal_with_weight_float():
@@ -93,8 +118,10 @@ def test_lerp_model_parallel_repeated_cal_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((1, 2, 2), (2,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = GradWrap(NetWithWeightFloat(input_weight_float_, strategy))
+    phase = compile_net(net, input_start_, input_end_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs("Lerp-0", ["_VirtualDiv-0", "_VirtualDiv-1", input_weight_float_])
 
 
 def test_lerp_data_parallel_with_weight_tensor():
@@ -104,8 +131,7 @@ def test_lerp_data_parallel_with_weight_tensor():
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    strategy = ((8, 1, 1), (1,), (1, 1))
-    net = Net(strategy)
+    net = Net()
     compile_net(net, input_start_, input_end_, input_weight_tensor_)
@@ -116,9 +142,8 @@ def test_lerp_data_parallel_with_weight_float():
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    strategy = ((8, 1, 1), (1,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_strategy_error_with_weight_tensor():
@@ -142,6 +167,6 @@ def test_lerp_strategy_error_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((4, 1, 2), (1,))
-    net = Net(strategy)
+    net = NetWithWeightFloat(input_weight_float_, strategy)
     with pytest.raises(RuntimeError):
-        compile_net(net, input_start_, input_end_, input_weight_float_)
+        compile_net(net, input_start_, input_end_)