forked from mindspore-Ecosystem/mindspore
!26494 [Auto parallel] Adjusting sharding propagation
Merge pull request !26494 from Xiaoda/102-adjusting-sharding-propagation
This commit is contained in:
commit
7559d5b798
|
@ -349,8 +349,9 @@ CostPtr Edge::GetCostByStrategyPair(const CostPtrKey &stra_pair) {
|
|||
return cost_vec[0];
|
||||
}
|
||||
|
||||
StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPtr &prev_op_stra) {
|
||||
StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithMiniComm(const StrategyPtr &prev_op_stra) {
|
||||
std::vector<std::pair<StrategyPtr, double>> next_op_stras;
|
||||
// First, try to find the strategy with zero communication cost.
|
||||
for (auto &key_value : cost_map_) {
|
||||
const auto &candidate_prev_op_stra = key_value.first.first;
|
||||
if (prev_op_stra->IsEqual(candidate_prev_op_stra) && (key_value.second[0]->communication_cost_ == 0.0)) {
|
||||
|
@ -358,12 +359,28 @@ StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPt
|
|||
}
|
||||
}
|
||||
if (next_op_stras.empty()) {
|
||||
MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
|
||||
return nullptr;
|
||||
} else if (next_op_stras.size() > 1) {
|
||||
// Second, if there is not strategy with zero communication cost, find the one with minimum communication cost.
|
||||
std::vector<std::pair<StrategyPtr, double>> next_stras;
|
||||
for (auto &key_value : cost_map_) {
|
||||
const auto &candidate_prev_op_stra = key_value.first.first;
|
||||
if (prev_op_stra->IsEqual(candidate_prev_op_stra)) {
|
||||
(void)next_stras.emplace_back(key_value.first.second, key_value.second[0]->communication_cost_);
|
||||
}
|
||||
}
|
||||
if (next_stras.empty()) {
|
||||
MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
|
||||
std::sort(next_stras.begin(), next_stras.end(),
|
||||
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
|
||||
return a.second <= b.second;
|
||||
});
|
||||
return next_stras[0].first;
|
||||
}
|
||||
if (next_op_stras.size() > 1) {
|
||||
MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
|
||||
<< ", choose the one with"
|
||||
" minimum computation costs.";
|
||||
<< " with zero communication cost, choose the one with minimum computation costs.";
|
||||
}
|
||||
std::sort(next_op_stras.begin(), next_op_stras.end(),
|
||||
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
|
||||
|
@ -372,8 +389,9 @@ StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPt
|
|||
return next_op_stras[0].first;
|
||||
}
|
||||
|
||||
StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPtr &next_op_stra) {
|
||||
StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithMiniComm(const StrategyPtr &next_op_stra) {
|
||||
std::vector<std::pair<StrategyPtr, double>> prev_op_stras;
|
||||
// First, try to find the strategy with zero communication cost.
|
||||
for (auto &key_value : cost_map_) {
|
||||
const auto &candidate_next_op_stra = key_value.first.second;
|
||||
if (next_op_stra->IsEqual(candidate_next_op_stra) && (key_value.second[0]->communication_cost_ == 0.0)) {
|
||||
|
@ -381,12 +399,28 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPt
|
|||
}
|
||||
}
|
||||
if (prev_op_stras.empty()) {
|
||||
MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
|
||||
return nullptr;
|
||||
} else if (prev_op_stras.size() > 1) {
|
||||
// Second, if there is no strategy with zero communication cost, find the one with minimum communication cost.
|
||||
std::vector<std::pair<StrategyPtr, double>> prev_stras;
|
||||
for (auto &key_value : cost_map_) {
|
||||
const auto &candidate_next_op_stra = key_value.first.second;
|
||||
if (next_op_stra->IsEqual(candidate_next_op_stra)) {
|
||||
(void)prev_stras.emplace_back(key_value.first.first, key_value.second[0]->communication_cost_);
|
||||
}
|
||||
}
|
||||
if (prev_stras.empty()) {
|
||||
MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
|
||||
std::sort(prev_stras.begin(), prev_stras.end(),
|
||||
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
|
||||
return a.second <= b.second;
|
||||
});
|
||||
return prev_stras[0].first;
|
||||
}
|
||||
if (prev_op_stras.size() > 1) {
|
||||
MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
|
||||
<< ", choose the one with minimum "
|
||||
"computation costs.";
|
||||
<< " with zero communication costs, choose the one with minimum computation costs.";
|
||||
}
|
||||
std::sort(prev_op_stras.begin(), prev_op_stras.end(),
|
||||
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
|
||||
|
@ -395,12 +429,11 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPt
|
|||
return prev_op_stras[0].first;
|
||||
}
|
||||
|
||||
int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra, int64_t curr_depth,
|
||||
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
|
||||
if (prev_op_->name().find(RESHAPEINFO) == std::string::npos) {
|
||||
int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra) {
|
||||
if (!prev_op_->IsReshape()) {
|
||||
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s prev_op is not a Reshape.";
|
||||
}
|
||||
if (next_op_->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
if (next_op_->IsReshape()) {
|
||||
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
|
||||
}
|
||||
const auto &reshape_output_layout = next_op_->GetInputLayoutFromSWCByStrategy(next_op_stra, next_op_input_index_);
|
||||
|
@ -408,15 +441,11 @@ int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra
|
|||
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
|
||||
// First, try to find the zero communication strategy.
|
||||
auto swc_index = reshape_ptr->GetSWCIndexByOutputLayoutWithZeroComm(reshape_output_layout);
|
||||
if (swc_index == -1 && curr_depth == 0) {
|
||||
const auto &prev_edges = reshape_ptr->prev_edges();
|
||||
if (!prev_edges.empty()) {
|
||||
auto prev_edge = prev_edges[0];
|
||||
if (configured_ops.find(prev_edge->prev_operator()) != configured_ops.end()) {
|
||||
// Here, it is sure that Reshape's previous and next operators are both configured strategies,
|
||||
// thus, it is OK that communication happens here.
|
||||
swc_index = reshape_ptr->GetSWCIndexByOutputLayout(reshape_output_layout);
|
||||
}
|
||||
if (swc_index == -1) {
|
||||
// Second, if there is no strategy with zero communication cost, find the strategy with minimum cost.
|
||||
swc_index = reshape_ptr->GetSWCIndexByOutputLayoutWithMiniComm(reshape_output_layout);
|
||||
if (swc_index != -1) {
|
||||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
|
||||
}
|
||||
}
|
||||
if (swc_index == -1) {
|
||||
|
@ -425,12 +454,11 @@ int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra
|
|||
return swc_index;
|
||||
}
|
||||
|
||||
int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra, int64_t curr_depth,
|
||||
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
|
||||
if (next_op_->name().find(RESHAPEINFO) == std::string::npos) {
|
||||
int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra) {
|
||||
if (!next_op_->IsReshape()) {
|
||||
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
|
||||
}
|
||||
if (prev_op_->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
if (prev_op_->IsReshape()) {
|
||||
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
|
||||
}
|
||||
const auto &reshape_input_lyt = prev_op_->GetOutputLayoutFromSWCByStrategy(prev_op_stra, prev_op_output_index_);
|
||||
|
@ -438,15 +466,11 @@ int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra
|
|||
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
|
||||
// First, try to find the zero communication strategy.
|
||||
auto swc_index = reshape_ptr->GetSWCIndexByInputLayoutWithZeroComm(reshape_input_lyt);
|
||||
if (swc_index == -1 && curr_depth == 0) {
|
||||
const auto &next_edges = reshape_ptr->succ_edges();
|
||||
if (!next_edges.empty()) {
|
||||
auto next_edge = next_edges[0];
|
||||
if (configured_ops.find(next_edge->next_operator()) != configured_ops.end()) {
|
||||
// Here, it is sure that Reshape's previous and next operators are both configured strategies,
|
||||
// thus, it is OK that communication happens here.
|
||||
swc_index = reshape_ptr->GetSWCIndexByInputLayout(reshape_input_lyt);
|
||||
}
|
||||
if (swc_index == -1) {
|
||||
// Second, if there is no zero communication strategy, find the strategy with minimum cost.
|
||||
swc_index = reshape_ptr->GetSWCIndexByInputLayoutWithMiniComm(reshape_input_lyt);
|
||||
if (swc_index != -1) {
|
||||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
|
||||
}
|
||||
}
|
||||
if (swc_index == -1) {
|
||||
|
@ -456,10 +480,10 @@ int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra
|
|||
}
|
||||
|
||||
StrategyPtr Edge::GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index) {
|
||||
if (next_op_->name().find(RESHAPEINFO) == std::string::npos) {
|
||||
if (!next_op_->IsReshape()) {
|
||||
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
|
||||
}
|
||||
if (prev_op_->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
if (prev_op_->IsReshape()) {
|
||||
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
|
||||
}
|
||||
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
|
||||
|
@ -472,10 +496,10 @@ StrategyPtr Edge::GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index) {
|
|||
}
|
||||
|
||||
StrategyPtr Edge::GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index) {
|
||||
if (prev_op_->name().find(RESHAPEINFO) == std::string::npos) {
|
||||
if (!prev_op_->IsReshape()) {
|
||||
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
|
||||
}
|
||||
if (next_op_->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
if (next_op_->IsReshape()) {
|
||||
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
|
||||
}
|
||||
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
|
||||
|
@ -487,25 +511,18 @@ StrategyPtr Edge::GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index) {
|
|||
return stra;
|
||||
}
|
||||
|
||||
bool Edge::CheckStrategyConsistency(const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
|
||||
auto prev_stra = prev_op_->selected_strategy();
|
||||
auto next_stra = next_op_->selected_strategy();
|
||||
bool Edge::CheckStrategyConsistency(StrategyPtr prev_stra, StrategyPtr next_stra) {
|
||||
if (prev_stra == nullptr) {
|
||||
MS_LOG(EXCEPTION) << prev_op_->name() << "'s selected strategy is null!";
|
||||
}
|
||||
if (next_op_ == nullptr) {
|
||||
if (next_stra == nullptr) {
|
||||
MS_LOG(EXCEPTION) << next_op_->name() << "'s selected strategy is null!";
|
||||
}
|
||||
auto cost = GetCostByStrategyPair({prev_stra, next_stra});
|
||||
if ((configured_ops.find(prev_op_) == configured_ops.end() ||
|
||||
configured_ops.find(next_op_) == configured_ops.end()) &&
|
||||
(cost == nullptr || cost->communication_cost_ > 0.0)) {
|
||||
PrintStrategy(prev_op_->selected_strategy());
|
||||
PrintStrategy(next_op_->selected_strategy());
|
||||
MS_LOG(ERROR) << "There are redistribution cost occurs at edge: " << edge_name()
|
||||
<< ", consider configuring sharding strategies for two operators."
|
||||
<< " The full name of these two operators are: " << prev_op_->cnode()->fullname_with_scope()
|
||||
<< " and " << next_op_->cnode()->fullname_with_scope();
|
||||
if (cost == nullptr || cost->communication_cost_ > 0.0) {
|
||||
PrintStrategy(next_stra);
|
||||
PrintStrategy(next_stra);
|
||||
MS_LOG(WARNING) << "There are redistribution cost occurs at edge: " << edge_name() << ".";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -83,15 +83,13 @@ class Edge {
|
|||
std::map<CostPtrKey, CostPtrList> GetCostMap() { return cost_map_; }
|
||||
CostPtr GetCostByStrategyPair(const CostPtrKey &);
|
||||
|
||||
StrategyPtr GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPtr &);
|
||||
StrategyPtr GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPtr &);
|
||||
int64_t GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra, int64_t curr_depth,
|
||||
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
|
||||
int64_t GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra, int64_t curr_depth,
|
||||
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
|
||||
StrategyPtr GetNextOpStrategyByPrevOpStrategyWithMiniComm(const StrategyPtr &);
|
||||
StrategyPtr GetPrevOpStrategyByNextOpStrategyWithMiniComm(const StrategyPtr &);
|
||||
int64_t GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra);
|
||||
int64_t GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra);
|
||||
StrategyPtr GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index);
|
||||
StrategyPtr GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index);
|
||||
bool CheckStrategyConsistency(const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
|
||||
bool CheckStrategyConsistency(StrategyPtr, StrategyPtr);
|
||||
|
||||
void SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &);
|
||||
// For two operators u--->v, given the output tensor layout of u,
|
||||
|
|
|
@ -101,82 +101,71 @@ void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &
|
|||
}
|
||||
}
|
||||
|
||||
void CheckVisitedEdgeConsistency(const EdgePtr &edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
|
||||
void CheckVisitedEdgeConsistency(const EdgePtr &edge) {
|
||||
auto prev_op = edge->prev_operator();
|
||||
auto next_op = edge->next_operator();
|
||||
if (prev_op->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
if (prev_op->IsReshape()) {
|
||||
const auto &reshape_output_lyt =
|
||||
next_op->GetInputLayoutFromSWCByStrategy(next_op->selected_strategy(), edge->next_op_input_index());
|
||||
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op);
|
||||
auto consistency =
|
||||
reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
|
||||
if (!consistency) {
|
||||
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
}
|
||||
} else if (next_op->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
} else if (next_op->IsReshape()) {
|
||||
const auto &reshape_input_lyt =
|
||||
prev_op->GetOutputLayoutFromSWCByStrategy(prev_op->selected_strategy(), edge->prev_op_output_index());
|
||||
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op);
|
||||
auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
|
||||
if (!consistency) {
|
||||
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
}
|
||||
} else {
|
||||
auto consistency = edge->CheckStrategyConsistency(configured_ops);
|
||||
auto consistency = edge->CheckStrategyConsistency(prev_op->selected_strategy(), next_op->selected_strategy());
|
||||
if (!consistency) {
|
||||
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckConfiguredSuccEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops,
|
||||
int64_t curr_depth) {
|
||||
void CheckConfiguredSuccEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
|
||||
auto curr_op = edge->prev_operator();
|
||||
auto next_op = edge->next_operator();
|
||||
if ((curr_op->name().find(RESHAPEINFO) != std::string::npos) && curr_depth > 1) {
|
||||
auto next_op_conf_stra = configured_ops[next_op];
|
||||
auto next_op_conf_stra = configured_ops[next_op];
|
||||
if (curr_op->IsReshape()) {
|
||||
const auto &reshape_output_lyt =
|
||||
next_op->GetInputLayoutFromSWCByStrategy(next_op_conf_stra, edge->next_op_input_index());
|
||||
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
|
||||
auto consistency =
|
||||
reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
|
||||
if (!consistency) {
|
||||
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
}
|
||||
} else if (curr_op->name().find(RESHAPEINFO) == std::string::npos) {
|
||||
const auto &next_op_conf_stra = configured_ops[next_op];
|
||||
const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
|
||||
if ((next_op_conf_stra == nullptr) || (!next_op_conf_stra->IsEqual(next_op_stra))) {
|
||||
MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
|
||||
<< "Currently reaching " << curr_op->name() << " and " << next_op->name() << "."
|
||||
<< " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
|
||||
<< " and " << next_op->cnode()->fullname_with_scope();
|
||||
} else {
|
||||
auto consistency = edge->CheckStrategyConsistency(curr_op->selected_strategy(), next_op_conf_stra);
|
||||
if (!consistency) {
|
||||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckConfiguredPrevEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops,
|
||||
int64_t curr_depth) {
|
||||
void CheckConfiguredPrevEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
|
||||
auto curr_op = edge->next_operator();
|
||||
auto prev_op = edge->prev_operator();
|
||||
if (curr_op->name().find(RESHAPEINFO) != std::string::npos && curr_depth > 1) {
|
||||
auto prev_op_conf_stra = configured_ops[prev_op];
|
||||
auto prev_op_conf_stra = configured_ops[prev_op];
|
||||
if (curr_op->IsReshape()) {
|
||||
const auto &reshape_input_lyt =
|
||||
prev_op->GetOutputLayoutFromSWCByStrategy(prev_op_conf_stra, edge->prev_op_output_index());
|
||||
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
|
||||
auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
|
||||
if (!consistency) {
|
||||
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
}
|
||||
} else if (curr_op->name().find(RESHAPEINFO) == std::string::npos) {
|
||||
const auto &prev_op_conf_stra = configured_ops[prev_op];
|
||||
const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
|
||||
if ((prev_op_conf_stra == nullptr) || (!prev_op_conf_stra->IsEqual(prev_op_stra))) {
|
||||
MS_LOG(ERROR) << "curr_depth: " << curr_depth;
|
||||
MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
|
||||
<< "Currently reaching " << prev_op->name() << " and " << curr_op->name() << "."
|
||||
<< " The full name of these two operators are: " << prev_op->cnode()->fullname_with_scope()
|
||||
<< " and " << curr_op->cnode()->fullname_with_scope();
|
||||
} else {
|
||||
auto consistency = edge->CheckStrategyConsistency(prev_op_conf_stra, curr_op->selected_strategy());
|
||||
if (!consistency) {
|
||||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -192,7 +181,7 @@ void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
|
|||
auto curr_depth = next_level.front().second;
|
||||
visited->at(curr_op) = true;
|
||||
MS_LOG(INFO) << "curr_depth: " << curr_depth;
|
||||
if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
if (curr_op->IsReshape()) {
|
||||
curr_op->set_swc_index(configured_stra_index, curr_depth);
|
||||
} else {
|
||||
curr_op->SetSelectedStrategy(configured_stra, curr_depth);
|
||||
|
@ -200,24 +189,23 @@ void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
|
|||
for (auto &edge : curr_op->succ_edges()) {
|
||||
const auto &next_op = edge->next_operator();
|
||||
if (visited->at(next_op)) {
|
||||
CheckVisitedEdgeConsistency(edge, configured_ops);
|
||||
CheckVisitedEdgeConsistency(edge);
|
||||
continue;
|
||||
}
|
||||
if ((curr_depth > 0) && (configured_ops.find(next_op) != configured_ops.end())) {
|
||||
CheckConfiguredSuccEdgeConsistency(edge, configured_ops, curr_depth);
|
||||
CheckConfiguredSuccEdgeConsistency(edge, configured_ops);
|
||||
}
|
||||
if (configured_ops.find(next_op) != configured_ops.end()) {
|
||||
continue;
|
||||
}
|
||||
if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
if (curr_op->IsReshape()) {
|
||||
auto stra = edge->GetNextOpStrategyByReshapeSWCIndex(curr_op->swc_index());
|
||||
(void)next_level.emplace(std::make_pair(next_op, std::make_pair(stra, -1)), curr_depth + 1);
|
||||
} else if (next_op->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
auto swc_index =
|
||||
edge->GetReshapeSWCIndexByPrevOpStrategy(curr_op->selected_strategy(), curr_depth, configured_ops);
|
||||
} else if (next_op->IsReshape()) {
|
||||
auto swc_index = edge->GetReshapeSWCIndexByPrevOpStrategy(curr_op->selected_strategy());
|
||||
(void)next_level.emplace(std::make_pair(next_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
|
||||
} else {
|
||||
const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
|
||||
const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithMiniComm(curr_op->selected_strategy());
|
||||
if (next_op_stra == nullptr) {
|
||||
PrintStrategy(curr_op->selected_strategy());
|
||||
MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
|
||||
|
@ -228,24 +216,23 @@ void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
|
|||
for (auto &edge : curr_op->prev_edges()) {
|
||||
const auto &prev_op = edge->prev_operator();
|
||||
if (visited->at(prev_op)) {
|
||||
CheckVisitedEdgeConsistency(edge, configured_ops);
|
||||
CheckVisitedEdgeConsistency(edge);
|
||||
continue;
|
||||
}
|
||||
if ((curr_depth > 0) && (configured_ops.find(prev_op) != configured_ops.end())) {
|
||||
CheckConfiguredPrevEdgeConsistency(edge, configured_ops, curr_depth);
|
||||
CheckConfiguredPrevEdgeConsistency(edge, configured_ops);
|
||||
}
|
||||
if (configured_ops.find(prev_op) != configured_ops.end()) {
|
||||
continue;
|
||||
}
|
||||
if (prev_op->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
auto swc_index =
|
||||
edge->GetReshapeSWCIndexByNextOpStrategy(curr_op->selected_strategy(), curr_depth, configured_ops);
|
||||
if (prev_op->IsReshape()) {
|
||||
auto swc_index = edge->GetReshapeSWCIndexByNextOpStrategy(curr_op->selected_strategy());
|
||||
(void)next_level.emplace(std::make_pair(prev_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
|
||||
} else if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
} else if (curr_op->IsReshape()) {
|
||||
auto prev_stra = edge->GetPrevOpStrategyByReshapeSWCIndex(curr_op->swc_index());
|
||||
(void)next_level.emplace(std::make_pair(prev_op, std::make_pair(prev_stra, -1)), curr_depth + 1);
|
||||
} else {
|
||||
const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
|
||||
const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithMiniComm(curr_op->selected_strategy());
|
||||
if (prev_op_stra == nullptr) {
|
||||
PrintStrategy(curr_op->selected_strategy());
|
||||
MS_LOG(EXCEPTION) << prev_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
|
||||
|
@ -1644,7 +1631,7 @@ size_t CostGraph::GetNumEdges() const {
|
|||
Status CostGraph::InitReshapeStrategy() {
|
||||
// reshape init should be apply after the init of it's previous node and next node.
|
||||
for (size_t i = 0; i < ops_.size(); ++i) {
|
||||
if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
if (ops_[i]->IsReshape()) {
|
||||
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
|
||||
auto in_edges = GetOriginalPrevEdges(ops_[i]);
|
||||
auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
|
||||
|
@ -1699,7 +1686,7 @@ Status CostGraph::InitReshapeStrategy() {
|
|||
Status CostGraph::InitSelectedStrategy() {
|
||||
for (auto &op : ops_) {
|
||||
MS_EXCEPTION_IF_NULL(op);
|
||||
if (op->name().find(RESHAPEINFO) != std::string::npos) {
|
||||
if (op->IsReshape()) {
|
||||
continue;
|
||||
}
|
||||
auto result_op = op->InitSelectedStrategy(op->selected_strategy());
|
||||
|
|
|
@ -1421,6 +1421,13 @@ StrategyPtr OperatorInfo::GetStrategyFromSWCByOutputLayout(TensorLayout output_l
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
bool OperatorInfo::IsReshape() {
|
||||
if (name_.find(RESHAPEINFO) != std::string::npos) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Keep at most (1.0 / epsilon) number of available strategies for each operator.
|
||||
void OperatorInfo::ApproximateStrategies() {
|
||||
auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();
|
||||
|
|
|
@ -150,6 +150,7 @@ class OperatorInfo {
|
|||
TensorLayout GetOutputLayoutFromSWCByStrategy(StrategyPtr stra, size_t output_index);
|
||||
StrategyPtr GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index);
|
||||
StrategyPtr GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index);
|
||||
bool IsReshape();
|
||||
|
||||
void set_swc_index(int64_t, int64_t);
|
||||
int64_t swc_index() { return swc_index_; }
|
||||
|
|
|
@ -523,25 +523,25 @@ int64_t ReshapeInfo::GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &o
|
|||
return index_computation[0].first;
|
||||
}
|
||||
|
||||
int64_t ReshapeInfo::GetSWCIndexByOutputLayout(const TensorLayout &output_layout) {
|
||||
std::vector<std::pair<int64_t, double>> index_computation;
|
||||
int64_t ReshapeInfo::GetSWCIndexByOutputLayoutWithMiniComm(const TensorLayout &output_layout) {
|
||||
std::vector<std::pair<int64_t, double>> index_comm;
|
||||
for (size_t i = 0; i < strategy_cost_.size(); ++i) {
|
||||
const auto &swc = strategy_cost_[i];
|
||||
if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
|
||||
(void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
|
||||
(void)index_comm.emplace_back(SizeToLong(i), swc->cost_list[0]->communication_without_parameter_);
|
||||
}
|
||||
}
|
||||
if (index_computation.empty()) {
|
||||
if (index_comm.empty()) {
|
||||
MS_LOG(ERROR) << "There in no available strategy for zero communication cost for reshape: " << name();
|
||||
return -1;
|
||||
}
|
||||
if (index_computation.size() > 1) {
|
||||
if (index_comm.size() > 1) {
|
||||
MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
|
||||
}
|
||||
std::sort(
|
||||
index_computation.begin(), index_computation.end(),
|
||||
index_comm.begin(), index_comm.end(),
|
||||
[](const std::pair<size_t, double> &a, const std::pair<size_t, double> &b) { return a.second <= b.second; });
|
||||
return index_computation[0].first;
|
||||
return index_comm[0].first;
|
||||
}
|
||||
|
||||
int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &input_layout) {
|
||||
|
@ -566,25 +566,25 @@ int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &in
|
|||
return index_computation[0].first;
|
||||
}
|
||||
|
||||
int64_t ReshapeInfo::GetSWCIndexByInputLayout(const TensorLayout &input_layout) {
|
||||
std::vector<std::pair<int64_t, double>> index_computation;
|
||||
int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithMiniComm(const TensorLayout &input_layout) {
|
||||
std::vector<std::pair<int64_t, double>> index_comm;
|
||||
for (size_t i = 0; i < strategy_cost_.size(); ++i) {
|
||||
const auto &swc = strategy_cost_[i];
|
||||
if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
|
||||
(void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
|
||||
(void)index_comm.emplace_back(SizeToLong(i), swc->cost_list[0]->communication_without_parameter_);
|
||||
}
|
||||
}
|
||||
if (index_computation.empty()) {
|
||||
if (index_comm.empty()) {
|
||||
MS_LOG(ERROR) << "There in no available strategy for zero communication cost for reshape: " << name();
|
||||
return -1;
|
||||
}
|
||||
if (index_computation.size() > 1) {
|
||||
if (index_comm.size() > 1) {
|
||||
MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
|
||||
}
|
||||
std::sort(
|
||||
index_computation.begin(), index_computation.end(),
|
||||
index_comm.begin(), index_comm.end(),
|
||||
[](const std::pair<size_t, double> &a, const std::pair<size_t, double> &b) { return a.second <= b.second; });
|
||||
return index_computation[0].first;
|
||||
return index_comm[0].first;
|
||||
}
|
||||
|
||||
bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, const TensorLayout &output_layout) {
|
||||
|
@ -593,7 +593,13 @@ bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, cons
|
|||
return false;
|
||||
}
|
||||
const auto &swc = strategy_cost_[swc_index];
|
||||
return swc->outputs_ptr[0].tensor_layout() == output_layout;
|
||||
if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
|
||||
return true;
|
||||
}
|
||||
MS_LOG(WARNING) << name_ << "'s desired output layout is: " << output_layout.ToString() << ", while the selected "
|
||||
<< "output layout is: " << swc->outputs_ptr[0].tensor_layout().ToString()
|
||||
<< " and the input layout is: " << swc->inputs_ptr[0].tensor_layout().ToString();
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const TensorLayout &input_layout) {
|
||||
|
@ -602,7 +608,13 @@ bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const
|
|||
return false;
|
||||
}
|
||||
const auto &swc = strategy_cost_[swc_index];
|
||||
return swc->inputs_ptr[0].tensor_layout() == input_layout;
|
||||
if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
|
||||
return true;
|
||||
}
|
||||
MS_LOG(WARNING) << name_ << "'s desired input layout is:" << input_layout.ToString() << ", while the selected "
|
||||
<< "input layout is: " << swc->inputs_ptr[0].tensor_layout().ToString()
|
||||
<< " and the output layout is: " << swc->outputs_ptr[0].tensor_layout().ToString();
|
||||
return false;
|
||||
}
|
||||
|
||||
TensorLayout ReshapeInfo::GetInputLayoutBySWCIndex(int64_t swc_index) {
|
||||
|
|
|
@ -70,9 +70,9 @@ class ReshapeInfo : public OperatorInfo {
|
|||
int64_t next_operator_index() const { return next_operator_index_; }
|
||||
|
||||
int64_t GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &);
|
||||
int64_t GetSWCIndexByOutputLayout(const TensorLayout &);
|
||||
int64_t GetSWCIndexByOutputLayoutWithMiniComm(const TensorLayout &);
|
||||
int64_t GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &);
|
||||
int64_t GetSWCIndexByInputLayout(const TensorLayout &);
|
||||
int64_t GetSWCIndexByInputLayoutWithMiniComm(const TensorLayout &);
|
||||
bool CheckStrategyConsistencyByOutputLayout(int64_t, const TensorLayout &);
|
||||
bool CheckStrategyConsistencyByInputLayout(int64_t, const TensorLayout &);
|
||||
|
||||
|
|
|
@ -647,7 +647,7 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
|
|||
if (ParallelContext::GetInstance()->sharding_propagation() && (prev_prim->name() == CAST) &&
|
||||
(configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
|
||||
const auto next_op_stra = configured_stra_ops_[node_op_info];
|
||||
const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithZeroComm(next_op_stra);
|
||||
const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithMiniComm(next_op_stra);
|
||||
if (cast_stra == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "No available strategy for: " << prev_op_info->name();
|
||||
}
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
|
@ -83,5 +82,4 @@ def test_auto_parallel_activation4():
|
|||
strategy2 = None
|
||||
strategy3 = ((8, 2),)
|
||||
net = Net(_w1, strategy1, strategy2, strategy3)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
compile_net(net)
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, Parameter
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
|
@ -59,8 +58,7 @@ def test_auto_parallel_activation1():
|
|||
strategy2 = ((8, 1),)
|
||||
strategy3 = ((1, 8), (1, 1))
|
||||
net = Net(_w1, strategy1, strategy2, strategy3)
|
||||
with pytest.raises(RuntimeError):
|
||||
compile_net(net)
|
||||
compile_net(net)
|
||||
|
||||
def test_auto_parallel_activation2():
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,
|
||||
|
|
Loading…
Reference in New Issue