In sharding propagation, to keep the sharding strategy of a parameter consistent when it is used by multiple operators, we check every edge whose source node is a TmpIdentityInfo.
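To make the failure mode concrete, the sketch below mirrors the added test_reshape_auto_8: the parameter gamma is consumed by ReLU with strategy (1, 1) and by MatMul with strategy (1, 8), so the new check raises an exception that names gamma and lists both operators. This is a minimal standalone sketch, not part of the commit; the set_auto_parallel_context arguments (search_mode="sharding_propagation") and the _cell_graph_executor.compile call are assumptions about how the test file's compile_graph helper drives compilation.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import operations as P


class Net(nn.Cell):
    """The parameter gamma is shared by ReLU and MatMul."""
    def __init__(self):
        super().__init__()
        self.gamma = Parameter(Tensor(np.ones([2048, 2048]), dtype=ms.float32), name="gamma")
        self.add = P.TensorAdd()
        self.relu = P.ReLU().shard(((1, 1),))             # gamma sharded as (1, 1)
        self.matmul = P.MatMul().shard(((1, 1), (1, 8)))  # gamma sharded as (1, 8): conflict
        # Consistent alternative (compiles cleanly): P.MatMul().shard(((8, 1), (1, 1)))

    def construct(self, x):
        out = self.add(x, self.relu(self.gamma))
        return self.matmul(out, self.gamma)


context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="auto_parallel",
                                  search_mode="sharding_propagation")
net = Net()
net.set_train()
x = Tensor(np.ones([2048, 2048]), dtype=ms.float32)
# With this commit, compilation stops with an exception that reports the parameter
# and lists the operators holding inconsistent sharding strategies for it.
_cell_graph_executor.compile(net, x)

The consistent configuration noted in the comment corresponds to the added test_reshape_auto_9, which is expected to compile without error.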
This commit is contained in:
parent f9b9e6add6
commit 364858cbc9
@@ -520,8 +520,20 @@ bool Edge::CheckStrategyConsistency(StrategyPtr prev_stra, StrategyPtr next_stra
  }
  auto cost = GetCostByStrategyPair({prev_stra, next_stra});
  if (cost == nullptr || cost->communication_cost_ > 0.0) {
    MS_LOG(INFO) << "The edge " << edge_name_ << "'s strategy: ";
    PrintStrategy(prev_stra);
    PrintStrategy(next_stra);
    if (prev_op_->IsTmpIdentity()) {
      MS_LOG(ERROR) << "The parameter: " << prev_op_->refkey_parameter_name()
                    << " has been used by operators with "
                       "different sharding strategies. These operators are: ";
      auto const &succ_edges = prev_op_->succ_edges();
      for (auto const &succ_edge : succ_edges) {
        MS_LOG(ERROR) << succ_edge->next_operator()->name() << ", the corresponding fullname is: "
                      << succ_edge->next_operator()->cnode()->fullname_with_scope();
      }
      MS_LOG(EXCEPTION) << "Configure these operators with consistent sharding strategies.";
    }
    MS_LOG(WARNING) << "There is redistribution cost at edge: " << edge_name() << ".";
    return false;
  }

@@ -1428,6 +1428,13 @@ bool OperatorInfo::IsReshape() {
  return false;
}

bool OperatorInfo::IsTmpIdentity() {
  if (name_.find(IDENTITY_INFO) != std::string::npos) {
    return true;
  }
  return false;
}

// Keep at most (1.0 / epsilon) number of available strategies for each operator.
void OperatorInfo::ApproximateStrategies() {
  auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();

@@ -151,6 +151,7 @@ class OperatorInfo {
  StrategyPtr GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index);
  StrategyPtr GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index);
  bool IsReshape();
  bool IsTmpIdentity();

  void set_swc_index(int64_t, int64_t);
  int64_t swc_index() { return swc_index_; }

@@ -334,3 +334,56 @@ def test_reshape_depend_reshape():
    net = GradWrapTwoInput(NetWithLoss1(Net()))
    with pytest.raises(RuntimeError):
        compile_graph_two_input(net, device_num, x, y)

def test_reshape_auto_8():
    """
    Feature: Sharding propagation for a common parameter being used by multiple ops.
    Description: relu->add->matmul->mean, with inconsistent strategies configured for the shared parameter.
    Expectation: raise RuntimeError.
    """
    device_num = 8
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.gamma = Parameter(Tensor(np.ones([2048, 2048]), dtype=ms.float32), name="gamma")
            self.add = P.TensorAdd()
            self.relu = P.ReLU().shard(((1, 1),))
            self.mul2 = P.MatMul().shard(((1, 1), (1, 8)))
            self.mean = P.ReduceMean(keep_dims=True)

        def construct(self, x):
            out = self.add(x, self.relu(self.gamma))
            out = self.mul2(out, self.gamma)
            out = self.mean(out, -1)
            return out

    x = Tensor(np.ones([2048, 2048]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    with pytest.raises(RuntimeError):
        compile_graph(net, device_num, x)

def test_reshape_auto_9():
    """
    Feature: Sharding propagation for a common parameter being used by multiple ops.
    Description: relu->add->matmul->mean, with consistent strategies configured for the shared parameter.
    Expectation: compile done without error.
    """
    device_num = 8
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.gamma = Parameter(Tensor(np.ones([2048, 2048]), dtype=ms.float32), name="gamma")
            self.add = P.TensorAdd()
            self.relu = P.ReLU().shard(((1, 1),))
            self.mul2 = P.MatMul().shard(((8, 1), (1, 1)))
            self.mean = P.ReduceMean(keep_dims=True)

        def construct(self, x):
            out = self.add(x, self.relu(self.gamma))
            out = self.mul2(out, self.gamma)
            out = self.mean(out, -1)
            return out

    x = Tensor(np.ones([2048, 2048]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    compile_graph(net, device_num, x)