forked from mindspore-Ecosystem/mindspore

!48180 Fix bug for LerpInfo with mirror ops when 'value' is a float
Merge pull request !48180 from liuluobin/master_fix_lerp

This commit is contained in: commit 2424649dd0
@@ -302,6 +302,21 @@ Status LerpInfo::InferTensorMap() {
   return SUCCESS;
 }
 
+Status LerpInfo::InferMirrorOps() {
+  if (OperatorInfo::InferMirrorOps() != SUCCESS) {
+    return FAILED;
+  }
+  // No need to insert mirror ops
+  if (mirror_ops_.empty()) {
+    return SUCCESS;
+  }
+  if (mirror_ops_.size() == kSizeTwo) {
+    // Push empty mirror op for value
+    (void)mirror_ops_.emplace_back(OperatorVector());
+  }
+  return SUCCESS;
+}
+
 std::vector<StrategyPtr> LerpInfo::GenerateOpStrategies(int64_t stage_id) {
   if (inputs_size_ == 2) {
     return ArithmeticBase::GenerateOpStrategies(stage_id);
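The new InferMirrorOps override is the core of the fix: when 'value' is a Python float rather than a tensor, only 'start' and 'end' contribute entries to mirror_ops_, leaving the list one entry short of Lerp's three inputs. Appending an empty OperatorVector restores the one-entry-per-input alignment. Below is a minimal pure-Python sketch of that invariant; the names are illustrative, not MindSpore API.

# Illustrative model of the alignment invariant the fix maintains:
# mirror_ops holds one (possibly empty) operator list per Lerp input.
K_SIZE_TWO = 2  # stands in for MindSpore's kSizeTwo

def infer_mirror_ops(mirror_ops):
    if not mirror_ops:         # no mirror ops need to be inserted
        return mirror_ops
    if len(mirror_ops) == K_SIZE_TWO:
        mirror_ops.append([])  # empty mirror op for the float 'value'
    return mirror_ops

# With a float 'value', only 'start' and 'end' produced entries:
padded = infer_mirror_ops([["mirror_start"], ["mirror_end"]])
assert len(padded) == 3 and padded[2] == []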
@@ -322,6 +322,7 @@ class LerpInfo : public ArithmeticBase {
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
+  Status InferMirrorOps() override;
 
  private:
   size_t inputs_size_ = 0;
@@ -973,8 +973,8 @@ static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &no
   }
 
   if (mirror_ops.size() != node_size - 1) {
-    MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
-                      << (node_size - 1);
+    MS_LOG(EXCEPTION) << "For " << node->fullname_with_scope() << ", the Mirrorops's size is wrong! mirror_ops size is "
+                      << mirror_ops.size() << ", node_size is " << (node_size - 1);
   }
   return true;
 }
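This guard is what the padding above keeps satisfied: a CNode's inputs are [primitive, arg0, arg1, ...], so node_size - 1 is the argument count, and mirror_ops must carry exactly one entry per argument. The reworded exception (the second MS_LOG line replaces the first) adds the node's scope name so the offending operator can be identified. A plain-Python sketch of the check, again with illustrative names; the scope string is hypothetical.

# Plain-Python model of the guard, not MindSpore API.
def check_insert_mirror_ops(mirror_ops, node_inputs, scope="Lerp-op0"):
    expected = len(node_inputs) - 1  # node_inputs = [primitive, args...]
    if len(mirror_ops) != expected:
        raise RuntimeError(f"For {scope}, the Mirrorops's size is wrong! "
                           f"mirror_ops size is {len(mirror_ops)}, "
                           f"node_size is {expected}")
    return True

# Before the fix, a Lerp with a float 'value' presumably reached this check
# with 2 mirror-op entries against 3 arguments; the padded entry fixes that.
assert check_insert_mirror_ops([["m0"], ["m1"], []],
                               ["Lerp", "start", "end", 0.5])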
@@ -3,14 +3,16 @@ import pytest
 
 from mindspore import Tensor, context
 from mindspore.nn import Cell
 from mindspore.ops import composite as C
+import mindspore.ops as ops
 
-from parallel.utils.utils import compile_net
+from parallel.utils.utils import compile_net, ParallelValidator
 
 
 def setup_function():
     context.set_auto_parallel_context(dataset_strategy="full_batch")
 
 
 input_start_ = Tensor(np.random.normal(size=[8, 8, 8]).astype(np.float32))
 input_end_ = Tensor(np.random.normal(size=[8]).astype(np.float32))
 input_weight_tensor_ = Tensor(np.random.normal(size=[8, 8]).astype(np.float32))
@@ -27,6 +29,27 @@ class Net(Cell):
         return output
 
 
+class NetWithWeightFloat(Cell):
+    def __init__(self, weight, strategy=None):
+        super(NetWithWeightFloat, self).__init__()
+        self.weight = weight
+        self.lerp = ops.Lerp().shard(strategy)
+
+    def construct(self, *inputs):
+        output = self.lerp(*inputs, self.weight)
+        return output
+
+
+class GradWrap(Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+        self.grad_op = C.GradOperation()
+
+    def construct(self, *inputs):
+        return self.grad_op(self.network)(*inputs)
+
+
 def test_lerp_auto_parallel_with_weight_tensor():
     """
     Feature: test Lerp auto parallel
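Two helper cells are added above. NetWithWeightFloat bakes the float 'value' into the cell instead of passing it as a network input, which is why its shard strategies below cover only 'start' and 'end'. GradWrap runs the wrapped network through C.GradOperation() so that compilation also builds the backward graph, which the repeated-calculation tests rely on when asserting on the inserted _VirtualDiv operators. A usage sketch under the same assumptions as the tests (8 devices, semi_auto_parallel); the 0.5 stands in for the file's input_weight_float_.

# Sketch that mirrors the test setup below; not new API.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  device_num=8, global_rank=0)
strategy = ((1, 2, 2), (2,))  # shards 'start' and 'end' only
net = GradWrap(NetWithWeightFloat(0.5, strategy))
phase = compile_net(net, input_start_, input_end_)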
@@ -45,8 +68,8 @@ def test_lerp_auto_parallel_with_weight_float():
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
-    net = Net()
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_model_parallel_with_weight_tensor():
@@ -69,8 +92,8 @@ def test_lerp_model_parallel_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((2, 2, 2), (2,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_, strategy)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_model_parallel_repeated_cal_with_weight_tensor():
@@ -81,8 +104,10 @@ def test_lerp_model_parallel_repeated_cal_with_weight_tensor():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((1, 2, 2), (2,), (2, 2))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_tensor_)
+    net = GradWrap(Net(strategy))
+    phase = compile_net(net, input_start_, input_end_, input_weight_tensor_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs("Lerp-0", ["_VirtualDiv-0", "_VirtualDiv-1", "_VirtualDiv-2"])
 
 
 def test_lerp_model_parallel_repeated_cal_with_weight_float():
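The assertion above pins down the fixed graph for a tensor 'value': all three Lerp inputs arrive through _VirtualDiv operators. The float-'value' variant below checks the counterpart introduced by this commit: only the two tensor inputs are wrapped, while the third input remains the literal input_weight_float_.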
@@ -93,8 +118,10 @@ def test_lerp_model_parallel_repeated_cal_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((1, 2, 2), (2,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = GradWrap(NetWithWeightFloat(input_weight_float_, strategy))
+    phase = compile_net(net, input_start_, input_end_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs("Lerp-0", ["_VirtualDiv-0", "_VirtualDiv-1", input_weight_float_])
 
 
 def test_lerp_data_parallel_with_weight_tensor():
@@ -104,8 +131,7 @@ def test_lerp_data_parallel_with_weight_tensor():
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    strategy = ((8, 1, 1), (1,), (1, 1))
-    net = Net(strategy)
+    net = Net()
     compile_net(net, input_start_, input_end_, input_weight_tensor_)
 
 
@@ -116,9 +142,8 @@ def test_lerp_data_parallel_with_weight_float():
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    strategy = ((8, 1, 1), (1,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_strategy_error_with_weight_tensor():
@@ -142,6 +167,6 @@ def test_lerp_strategy_error_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((4, 1, 2), (1,))
-    net = Net(strategy)
+    net = NetWithWeightFloat(input_weight_float_, strategy)
     with pytest.raises(RuntimeError):
-        compile_net(net, input_start_, input_end_, input_weight_float_)
+        compile_net(net, input_start_, input_end_)