forked from mindspore-Ecosystem/mindspore
!48180 Fix bug for LerpInfo with mirror ops when 'value' is float type
Merge pull request !48180 from liuluobin/master_fix_lerp
commit 2424649dd0
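Background for the fix: Lerp takes three inputs (start, end, value), and 'value' may be passed as a Python float rather than a Tensor. The sketch below is illustrative only (the class name is hypothetical, adapted from the NetWithWeightFloat test class added in this patch) and shows the parallel configuration that this commit addresses:

import numpy as np
import mindspore.ops as ops
from mindspore import Tensor, context
from mindspore.nn import Cell


class LerpFloatValueNet(Cell):  # hypothetical name, mirrors NetWithWeightFloat below
    def __init__(self, value, strategy=None):
        super(LerpFloatValueNet, self).__init__()
        self.value = value
        self.lerp = ops.Lerp().shard(strategy)

    def construct(self, start, end):
        # 'value' is a scalar float, so it carries no shard strategy of its own
        return self.lerp(start, end, self.value)


context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
# The strategy covers only the two tensor inputs; the float 'value' has none.
net = LerpFloatValueNet(0.5, strategy=((2, 2, 2), (2,)))
start = Tensor(np.random.normal(size=[8, 8, 8]).astype(np.float32))
end = Tensor(np.random.normal(size=[8]).astype(np.float32))
# Compiling a call like net(start, end) would previously trip the mirror-ops
# size check shown in the CheckInsertMirrorOps hunk below.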
@@ -302,6 +302,21 @@ Status LerpInfo::InferTensorMap() {
   return SUCCESS;
 }
 
+Status LerpInfo::InferMirrorOps() {
+  if (OperatorInfo::InferMirrorOps() != SUCCESS) {
+    return FAILED;
+  }
+  // No need to insert mirror ops
+  if (mirror_ops_.empty()) {
+    return SUCCESS;
+  }
+  if (mirror_ops_.size() == kSizeTwo) {
+    // Push empty mirror op for value
+    (void)mirror_ops_.emplace_back(OperatorVector());
+  }
+  return SUCCESS;
+}
+
 std::vector<StrategyPtr> LerpInfo::GenerateOpStrategies(int64_t stage_id) {
   if (inputs_size_ == 2) {
     return ArithmeticBase::GenerateOpStrategies(stage_id);
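Why the padding is needed: when 'value' is a float, OperatorInfo::InferMirrorOps() builds mirror ops only for the two tensor inputs (start and end), so mirror_ops_ has size two while the Lerp CNode has three inputs. CheckInsertMirrorOps (see the hunk further down) raises an exception unless mirror_ops.size() equals node_size - 1, so an empty OperatorVector is appended as a placeholder for the scalar 'value'.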
@@ -322,6 +322,7 @@ class LerpInfo : public ArithmeticBase {
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
+  Status InferMirrorOps() override;
 
  private:
   size_t inputs_size_ = 0;
@@ -973,8 +973,8 @@ static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &no
   }
 
   if (mirror_ops.size() != node_size - 1) {
-    MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
-                      << (node_size - 1);
+    MS_LOG(EXCEPTION) << "For " << node->fullname_with_scope() << ", the Mirrorops's size is wrong! mirror_ops size is "
+                      << mirror_ops.size() << ", node_size is " << (node_size - 1);
   }
   return true;
 }
@@ -3,14 +3,16 @@ import pytest
 
 from mindspore import Tensor, context
 from mindspore.nn import Cell
+from mindspore.ops import composite as C
 import mindspore.ops as ops
 
-from parallel.utils.utils import compile_net
+from parallel.utils.utils import compile_net, ParallelValidator
 
 
 def setup_function():
     context.set_auto_parallel_context(dataset_strategy="full_batch")
 
 
 input_start_ = Tensor(np.random.normal(size=[8, 8, 8]).astype(np.float32))
 input_end_ = Tensor(np.random.normal(size=[8]).astype(np.float32))
 input_weight_tensor_ = Tensor(np.random.normal(size=[8, 8]).astype(np.float32))
@@ -27,6 +29,27 @@ class Net(Cell):
         return output
 
 
+class NetWithWeightFloat(Cell):
+    def __init__(self, weight, strategy=None):
+        super(NetWithWeightFloat, self).__init__()
+        self.weight = weight
+        self.lerp = ops.Lerp().shard(strategy)
+
+    def construct(self, *inputs):
+        output = self.lerp(*inputs, self.weight)
+        return output
+
+
+class GradWrap(Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+        self.grad_op = C.GradOperation()
+
+    def construct(self, *inputs):
+        return self.grad_op(self.network)(*inputs)
+
+
 def test_lerp_auto_parallel_with_weight_tensor():
     """
     Feature: test Lerp auto parallel
@@ -45,8 +68,8 @@ def test_lerp_auto_parallel_with_weight_float():
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
-    net = Net()
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_model_parallel_with_weight_tensor():
@@ -69,8 +92,8 @@ def test_lerp_model_parallel_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((2, 2, 2), (2,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_, strategy)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_model_parallel_repeated_cal_with_weight_tensor():
@@ -81,8 +104,10 @@ def test_lerp_model_parallel_repeated_cal_with_weight_tensor():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((1, 2, 2), (2,), (2, 2))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_tensor_)
+    net = GradWrap(Net(strategy))
+    phase = compile_net(net, input_start_, input_end_, input_weight_tensor_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs("Lerp-0", ["_VirtualDiv-0", "_VirtualDiv-1", "_VirtualDiv-2"])
 
 
 def test_lerp_model_parallel_repeated_cal_with_weight_float():
@@ -93,8 +118,10 @@ def test_lerp_model_parallel_repeated_cal_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((1, 2, 2), (2,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = GradWrap(NetWithWeightFloat(input_weight_float_, strategy))
+    phase = compile_net(net, input_start_, input_end_)
+    validator = ParallelValidator(net, phase)
+    assert validator.check_node_inputs("Lerp-0", ["_VirtualDiv-0", "_VirtualDiv-1", input_weight_float_])
 
 
 def test_lerp_data_parallel_with_weight_tensor():
@@ -104,8 +131,7 @@ def test_lerp_data_parallel_with_weight_tensor():
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    strategy = ((8, 1, 1), (1,), (1, 1))
-    net = Net(strategy)
+    net = Net()
     compile_net(net, input_start_, input_end_, input_weight_tensor_)
 
 
@@ -116,9 +142,8 @@ def test_lerp_data_parallel_with_weight_float():
     Expectation: compile success
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
-    strategy = ((8, 1, 1), (1,))
-    net = Net(strategy)
-    compile_net(net, input_start_, input_end_, input_weight_float_)
+    net = NetWithWeightFloat(input_weight_float_)
+    compile_net(net, input_start_, input_end_)
 
 
 def test_lerp_strategy_error_with_weight_tensor():
@@ -142,6 +167,6 @@ def test_lerp_strategy_error_with_weight_float():
     """
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy = ((4, 1, 2), (1,))
-    net = Net(strategy)
+    net = NetWithWeightFloat(input_weight_float_, strategy)
     with pytest.raises(RuntimeError):
-        compile_net(net, input_start_, input_end_, input_weight_float_)
+        compile_net(net, input_start_, input_end_)