In sharding propagation, to keep the sharding strategies of a parameter used by multiple operators consistent, we check the edges whose source node is a TmpIdentityInfo operator

Xiaoda Zhang 2021-11-30 18:58:33 +08:00
parent f9b9e6add6
commit 364858cbc9
4 changed files with 74 additions and 1 deletion
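To illustrate the situation this commit detects, the sketch below condenses the pattern from test_reshape_auto_8 added in this commit: the shared parameter gamma is kept unsplit by ReLU's strategy but split along its last axis by MatMul's strategy. The class name is illustrative, and the auto-parallel context and compilation wiring (8 devices, sharding propagation) are omitted here and assumed to match the tests' compile_graph helper.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P


class SharedParamNet(nn.Cell):  # illustrative name, not part of this commit
    def __init__(self):
        super().__init__()
        self.gamma = Parameter(Tensor(np.ones([2048, 2048]), dtype=ms.float32), name="gamma")
        self.add = P.TensorAdd()
        self.relu = P.ReLU().shard(((1, 1),))           # consumes gamma unsplit: (1, 1)
        self.mul2 = P.MatMul().shard(((1, 1), (1, 8)))  # consumes gamma split along its last axis
        self.mean = P.ReduceMean(keep_dims=True)

    def construct(self, x):
        out = self.add(x, self.relu(self.gamma))
        out = self.mul2(out, self.gamma)
        return self.mean(out, -1)

# Compiling such a network in auto-parallel mode with sharding propagation now raises an
# exception listing the conflicting operators and asking for consistent sharding strategies,
# instead of only logging a redistribution-cost warning.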


@@ -520,8 +520,20 @@ bool Edge::CheckStrategyConsistency(StrategyPtr prev_stra, StrategyPtr next_stra
  }
  auto cost = GetCostByStrategyPair({prev_stra, next_stra});
  if (cost == nullptr || cost->communication_cost_ > 0.0) {
    MS_LOG(INFO) << "The edge " << edge_name_ << "'s strategy: ";
    PrintStrategy(prev_stra);
    PrintStrategy(next_stra);
    if (prev_op_->IsTmpIdentity()) {
      MS_LOG(ERROR) << "The parameter: " << prev_op_->refkey_parameter_name()
                    << " has been used by operators with "
                       "different sharding strategies. These operators are: ";
      auto const &succ_edges = prev_op_->succ_edges();
      for (auto const &succ_edge : succ_edges) {
        MS_LOG(ERROR) << succ_edge->next_operator()->name() << ", the corresponding fullname is: "
                      << succ_edge->next_operator()->cnode()->fullname_with_scope();
      }
      MS_LOG(EXCEPTION) << "Configure these operators with consistent sharding strategies.";
    }
    MS_LOG(WARNING) << "There are redistribution cost occurs at edge: " << edge_name() << ".";
    return false;
  }


@@ -1428,6 +1428,13 @@ bool OperatorInfo::IsReshape() {
  return false;
}

bool OperatorInfo::IsTmpIdentity() {
  if (name_.find(IDENTITY_INFO) != std::string::npos) {
    return true;
  }
  return false;
}

// Keep at most (1.0 / epsilon) number of available strategies for each operator.
void OperatorInfo::ApproximateStrategies() {
  auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();


@@ -151,6 +151,7 @@ class OperatorInfo {
  StrategyPtr GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index);
  StrategyPtr GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index);
  bool IsReshape();
  bool IsTmpIdentity();
  void set_swc_index(int64_t, int64_t);
  int64_t swc_index() { return swc_index_; }


@@ -334,3 +334,56 @@ def test_reshape_depend_reshape():
    net = GradWrapTwoInput(NetWithLoss1(Net()))
    with pytest.raises(RuntimeError):
        compile_graph_two_input(net, device_num, x, y)

def test_reshape_auto_8():
    """
    Feature: Sharding propagation for a common parameter being used by multiple ops.
    Description: relu->add->mul->mean
    Expectation: compile fails with RuntimeError because the shared parameter gets inconsistent strategies.
    """
    device_num = 8

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.gamma = Parameter(Tensor(np.ones([2048, 2048]), dtype=ms.float32), name="gamma")
            self.add = P.TensorAdd()
            self.relu = P.ReLU().shard(((1, 1),))
            self.mul2 = P.MatMul().shard(((1, 1), (1, 8)))
            self.mean = P.ReduceMean(keep_dims=True)

        def construct(self, x):
            out = self.add(x, self.relu(self.gamma))
            out = self.mul2(out, self.gamma)
            out = self.mean(out, -1)
            return out

    x = Tensor(np.ones([2048, 2048]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    with pytest.raises(RuntimeError):
        compile_graph(net, device_num, x)

def test_reshape_auto_9():
    """
    Feature: Sharding propagation for common parameter being used by multiple ops.
    Description: relu->add->mul->mean
    Expectation: compile done without error.
    """
    device_num = 8

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.gamma = Parameter(Tensor(np.ones([2048, 2048]), dtype=ms.float32), name="gamma")
            self.add = P.TensorAdd()
            self.relu = P.ReLU().shard(((1, 1),))
            self.mul2 = P.MatMul().shard(((8, 1), (1, 1)))
            self.mean = P.ReduceMean(keep_dims=True)

        def construct(self, x):
            out = self.add(x, self.relu(self.gamma))
            out = self.mul2(out, self.gamma)
            out = self.mean(out, -1)
            return out

    x = Tensor(np.ones([2048, 2048]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    compile_graph(net, device_num, x)