!26494 [Auto parallel] Adjusting sharding propagation

Merge pull request !26494 from Xiaoda/102-adjusting-sharding-propagation
i-robot 2021-11-19 02:00:46 +00:00 committed by Gitee
commit 7559d5b798
10 changed files with 156 additions and 138 deletions


@@ -349,8 +349,9 @@ CostPtr Edge::GetCostByStrategyPair(const CostPtrKey &stra_pair) {
return cost_vec[0];
}
StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPtr &prev_op_stra) {
StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithMiniComm(const StrategyPtr &prev_op_stra) {
std::vector<std::pair<StrategyPtr, double>> next_op_stras;
// First, try to find the strategy with zero communication cost.
for (auto &key_value : cost_map_) {
const auto &candidate_prev_op_stra = key_value.first.first;
if (prev_op_stra->IsEqual(candidate_prev_op_stra) && (key_value.second[0]->communication_cost_ == 0.0)) {
@@ -358,12 +359,28 @@ StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPt
}
}
if (next_op_stras.empty()) {
MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
return nullptr;
} else if (next_op_stras.size() > 1) {
// Second, if there is no strategy with zero communication cost, find the one with minimum communication cost.
std::vector<std::pair<StrategyPtr, double>> next_stras;
for (auto &key_value : cost_map_) {
const auto &candidate_prev_op_stra = key_value.first.first;
if (prev_op_stra->IsEqual(candidate_prev_op_stra)) {
(void)next_stras.emplace_back(key_value.first.second, key_value.second[0]->communication_cost_);
}
}
if (next_stras.empty()) {
MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
return nullptr;
}
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
std::sort(next_stras.begin(), next_stras.end(),
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
return a.second < b.second;
});
return next_stras[0].first;
}
if (next_op_stras.size() > 1) {
MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
<< ", choose the one with"
" minimum computation costs.";
<< " with zero communication cost, choose the one with minimum computation costs.";
}
std::sort(next_op_stras.begin(), next_op_stras.end(),
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
@@ -372,8 +389,9 @@ StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPt
return next_op_stras[0].first;
}
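Note: the rewritten selection is two-phase. Prefer a next-op strategy whose pair with the given prev-op strategy has zero communication cost; only when none exists, fall back to the pair with minimum communication cost and warn about the inconsistency (the real code additionally sorts zero-communication candidates by computation cost). A minimal standalone sketch of that policy, using ints and doubles as stand-ins for StrategyPtr and the edge cost map; the names are illustrative, not MindSpore APIs:

#include <algorithm>
#include <iostream>
#include <optional>
#include <utility>
#include <vector>

// Candidate: (next-op strategy id, communication cost of the pair).
using Candidate = std::pair<int, double>;

// Phase 1: any zero-communication candidate. Phase 2: minimum-communication
// fallback, mirroring GetNextOpStrategyByPrevOpStrategyWithMiniComm above.
std::optional<int> PickNextStrategy(const std::vector<Candidate> &candidates) {
  for (const auto &c : candidates) {
    if (c.second == 0.0) {
      return c.first;  // zero communication cost: take it directly
    }
  }
  if (candidates.empty()) {
    return std::nullopt;  // no strategy available at all
  }
  std::cerr << "warning: no zero-communication strategy, falling back\n";
  auto min_it = std::min_element(
      candidates.begin(), candidates.end(),
      [](const Candidate &a, const Candidate &b) { return a.second < b.second; });
  return min_it->first;
}

int main() {
  std::vector<Candidate> candidates = {{0, 2.0}, {1, 0.5}, {2, 1.0}};
  if (auto s = PickNextStrategy(candidates)) {
    std::cout << "picked strategy " << *s << "\n";  // picks 1 (cost 0.5)
  }
  return 0;
}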
StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPtr &next_op_stra) {
StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithMiniComm(const StrategyPtr &next_op_stra) {
std::vector<std::pair<StrategyPtr, double>> prev_op_stras;
// First, try to find the strategy with zero communication cost.
for (auto &key_value : cost_map_) {
const auto &candidate_next_op_stra = key_value.first.second;
if (next_op_stra->IsEqual(candidate_next_op_stra) && (key_value.second[0]->communication_cost_ == 0.0)) {
@@ -381,12 +399,28 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPt
}
}
if (prev_op_stras.empty()) {
MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
return nullptr;
} else if (prev_op_stras.size() > 1) {
// Second, if there is no strategy with zero communication cost, find the one with minimum communication cost.
std::vector<std::pair<StrategyPtr, double>> prev_stras;
for (auto &key_value : cost_map_) {
const auto &candidate_next_op_stra = key_value.first.second;
if (next_op_stra->IsEqual(candidate_next_op_stra)) {
(void)prev_stras.emplace_back(key_value.first.first, key_value.second[0]->communication_cost_);
}
}
if (prev_stras.empty()) {
MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
return nullptr;
}
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
std::sort(prev_stras.begin(), prev_stras.end(),
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
return a.second < b.second;
});
return prev_stras[0].first;
}
if (prev_op_stras.size() > 1) {
MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
<< ", choose the one with minimum "
"computation costs.";
<< " with zero communication costs, choose the one with minimum computation costs.";
}
std::sort(prev_op_stras.begin(), prev_op_stras.end(),
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
@@ -395,12 +429,11 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPt
return prev_op_stras[0].first;
}
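A detail worth noting about the sorts in these two functions (written with a strict < above): std::sort requires the comparator to be a strict weak ordering, so a <= comparison violates irreflexivity and is undefined behavior. A minimal sketch of the safe form:

#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<int, double>> stras = {{0, 3.0}, {1, 1.0}, {2, 2.0}};
  // Strict '<' on the cost: '<=' would not be a strict weak ordering and
  // would make std::sort undefined behavior.
  std::sort(stras.begin(), stras.end(),
            [](const auto &a, const auto &b) { return a.second < b.second; });
  std::cout << stras.front().first << "\n";  // 1: the minimum-cost strategy
  return 0;
}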
int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra, int64_t curr_depth,
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
if (prev_op_->name().find(RESHAPEINFO) == std::string::npos) {
int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra) {
if (!prev_op_->IsReshape()) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s prev_op is not a Reshape.";
}
if (next_op_->name().find(RESHAPEINFO) != std::string::npos) {
if (next_op_->IsReshape()) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
}
const auto &reshape_output_layout = next_op_->GetInputLayoutFromSWCByStrategy(next_op_stra, next_op_input_index_);
@@ -408,15 +441,11 @@ int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
// First, try to find the zero communication strategy.
auto swc_index = reshape_ptr->GetSWCIndexByOutputLayoutWithZeroComm(reshape_output_layout);
if (swc_index == -1 && curr_depth == 0) {
const auto &prev_edges = reshape_ptr->prev_edges();
if (!prev_edges.empty()) {
auto prev_edge = prev_edges[0];
if (configured_ops.find(prev_edge->prev_operator()) != configured_ops.end()) {
// Here, it is sure that Reshape's previous and next operators are both configured strategies,
// thus, it is OK that communication happens here.
swc_index = reshape_ptr->GetSWCIndexByOutputLayout(reshape_output_layout);
}
if (swc_index == -1) {
// Second, if there is no strategy with zero communication cost, find the strategy with minimum cost.
swc_index = reshape_ptr->GetSWCIndexByOutputLayoutWithMiniComm(reshape_output_layout);
if (swc_index != -1) {
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
}
}
if (swc_index == -1) {
@@ -425,12 +454,11 @@ int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra
return swc_index;
}
int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra, int64_t curr_depth,
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
if (next_op_->name().find(RESHAPEINFO) == std::string::npos) {
int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra) {
if (!next_op_->IsReshape()) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
}
if (prev_op_->name().find(RESHAPEINFO) != std::string::npos) {
if (prev_op_->IsReshape()) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
}
const auto &reshape_input_lyt = prev_op_->GetOutputLayoutFromSWCByStrategy(prev_op_stra, prev_op_output_index_);
@@ -438,15 +466,11 @@ int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
// First, try to find the zero communication strategy.
auto swc_index = reshape_ptr->GetSWCIndexByInputLayoutWithZeroComm(reshape_input_lyt);
if (swc_index == -1 && curr_depth == 0) {
const auto &next_edges = reshape_ptr->succ_edges();
if (!next_edges.empty()) {
auto next_edge = next_edges[0];
if (configured_ops.find(next_edge->next_operator()) != configured_ops.end()) {
// Here, it is sure that Reshape's previous and next operators are both configured strategies,
// thus, it is OK that communication happens here.
swc_index = reshape_ptr->GetSWCIndexByInputLayout(reshape_input_lyt);
}
if (swc_index == -1) {
// Second, if there is no zero communication strategy, find the strategy with minimum cost.
swc_index = reshape_ptr->GetSWCIndexByInputLayoutWithMiniComm(reshape_input_lyt);
if (swc_index != -1) {
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
}
}
if (swc_index == -1) {
@@ -456,10 +480,10 @@ int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra
}
StrategyPtr Edge::GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index) {
if (next_op_->name().find(RESHAPEINFO) == std::string::npos) {
if (!next_op_->IsReshape()) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
}
if (prev_op_->name().find(RESHAPEINFO) != std::string::npos) {
if (prev_op_->IsReshape()) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
}
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
@@ -472,10 +496,10 @@ StrategyPtr Edge::GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index) {
}
StrategyPtr Edge::GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index) {
if (prev_op_->name().find(RESHAPEINFO) == std::string::npos) {
if (!prev_op_->IsReshape()) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
}
if (next_op_->name().find(RESHAPEINFO) != std::string::npos) {
if (next_op_->IsReshape()) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
}
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
@@ -487,25 +511,18 @@ StrategyPtr Edge::GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index) {
return stra;
}
bool Edge::CheckStrategyConsistency(const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
auto prev_stra = prev_op_->selected_strategy();
auto next_stra = next_op_->selected_strategy();
bool Edge::CheckStrategyConsistency(StrategyPtr prev_stra, StrategyPtr next_stra) {
if (prev_stra == nullptr) {
MS_LOG(EXCEPTION) << prev_op_->name() << "'s selected strategy is null!";
}
if (next_op_ == nullptr) {
if (next_stra == nullptr) {
MS_LOG(EXCEPTION) << next_op_->name() << "'s selected strategy is null!";
}
auto cost = GetCostByStrategyPair({prev_stra, next_stra});
if ((configured_ops.find(prev_op_) == configured_ops.end() ||
configured_ops.find(next_op_) == configured_ops.end()) &&
(cost == nullptr || cost->communication_cost_ > 0.0)) {
PrintStrategy(prev_op_->selected_strategy());
PrintStrategy(next_op_->selected_strategy());
MS_LOG(ERROR) << "There are redistribution cost occurs at edge: " << edge_name()
<< ", consider configuring sharding strategies for two operators."
<< " The full name of these two operators are: " << prev_op_->cnode()->fullname_with_scope()
<< " and " << next_op_->cnode()->fullname_with_scope();
if (cost == nullptr || cost->communication_cost_ > 0.0) {
PrintStrategy(prev_stra);
PrintStrategy(next_stra);
MS_LOG(WARNING) << "There are redistribution cost occurs at edge: " << edge_name() << ".";
return false;
}
return true;
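CheckStrategyConsistency now takes the two selected strategies directly, looks up the cost of that strategy pair, and warns rather than throwing when the pair is unknown or carries a communication (redistribution) cost. A rough sketch of the logic, with a plain map standing in for the edge's cost map (hypothetical names):

#include <iostream>
#include <map>
#include <utility>

// (prev strategy id, next strategy id) -> communication cost of the pair.
using PairKey = std::pair<int, int>;

bool CheckConsistency(const std::map<PairKey, double> &cost_map,
                      int prev_stra, int next_stra) {
  auto it = cost_map.find({prev_stra, next_stra});
  if (it == cost_map.end() || it->second > 0.0) {
    // Unknown pair or nonzero redistribution cost: warn, do not abort.
    std::cerr << "warning: redistribution cost occurs at this edge\n";
    return false;
  }
  return true;
}

int main() {
  std::map<PairKey, double> cost_map = {{{0, 0}, 0.0}, {{0, 1}, 4.0}};
  std::cout << CheckConsistency(cost_map, 0, 0) << "\n";  // 1: consistent
  std::cout << CheckConsistency(cost_map, 0, 1) << "\n";  // 0: warns
  return 0;
}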


@@ -83,15 +83,13 @@ class Edge {
std::map<CostPtrKey, CostPtrList> GetCostMap() { return cost_map_; }
CostPtr GetCostByStrategyPair(const CostPtrKey &);
StrategyPtr GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPtr &);
StrategyPtr GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPtr &);
int64_t GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra, int64_t curr_depth,
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
int64_t GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra, int64_t curr_depth,
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
StrategyPtr GetNextOpStrategyByPrevOpStrategyWithMiniComm(const StrategyPtr &);
StrategyPtr GetPrevOpStrategyByNextOpStrategyWithMiniComm(const StrategyPtr &);
int64_t GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra);
int64_t GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra);
StrategyPtr GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index);
StrategyPtr GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index);
bool CheckStrategyConsistency(const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
bool CheckStrategyConsistency(StrategyPtr, StrategyPtr);
void SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &);
// For two operators u--->v, given the output tensor layout of u,


@@ -101,82 +101,71 @@ void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &
}
}
void CheckVisitedEdgeConsistency(const EdgePtr &edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
void CheckVisitedEdgeConsistency(const EdgePtr &edge) {
auto prev_op = edge->prev_operator();
auto next_op = edge->next_operator();
if (prev_op->name().find(RESHAPEINFO) != std::string::npos) {
if (prev_op->IsReshape()) {
const auto &reshape_output_lyt =
next_op->GetInputLayoutFromSWCByStrategy(next_op->selected_strategy(), edge->next_op_input_index());
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op);
auto consistency =
reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
if (!consistency) {
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
}
} else if (next_op->name().find(RESHAPEINFO) != std::string::npos) {
} else if (next_op->IsReshape()) {
const auto &reshape_input_lyt =
prev_op->GetOutputLayoutFromSWCByStrategy(prev_op->selected_strategy(), edge->prev_op_output_index());
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op);
auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
if (!consistency) {
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
}
} else {
auto consistency = edge->CheckStrategyConsistency(configured_ops);
auto consistency = edge->CheckStrategyConsistency(prev_op->selected_strategy(), next_op->selected_strategy());
if (!consistency) {
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
}
}
}
void CheckConfiguredSuccEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops,
int64_t curr_depth) {
void CheckConfiguredSuccEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
auto curr_op = edge->prev_operator();
auto next_op = edge->next_operator();
if ((curr_op->name().find(RESHAPEINFO) != std::string::npos) && curr_depth > 1) {
auto next_op_conf_stra = configured_ops[next_op];
auto next_op_conf_stra = configured_ops[next_op];
if (curr_op->IsReshape()) {
const auto &reshape_output_lyt =
next_op->GetInputLayoutFromSWCByStrategy(next_op_conf_stra, edge->next_op_input_index());
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
auto consistency =
reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
if (!consistency) {
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
}
} else if (curr_op->name().find(RESHAPEINFO) == std::string::npos) {
const auto &next_op_conf_stra = configured_ops[next_op];
const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
if ((next_op_conf_stra == nullptr) || (!next_op_conf_stra->IsEqual(next_op_stra))) {
MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
<< "Currently reaching " << curr_op->name() << " and " << next_op->name() << "."
<< " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
<< " and " << next_op->cnode()->fullname_with_scope();
} else {
auto consistency = edge->CheckStrategyConsistency(curr_op->selected_strategy(), next_op_conf_stra);
if (!consistency) {
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
}
}
}
void CheckConfiguredPrevEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops,
int64_t curr_depth) {
void CheckConfiguredPrevEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
auto curr_op = edge->next_operator();
auto prev_op = edge->prev_operator();
if (curr_op->name().find(RESHAPEINFO) != std::string::npos && curr_depth > 1) {
auto prev_op_conf_stra = configured_ops[prev_op];
auto prev_op_conf_stra = configured_ops[prev_op];
if (curr_op->IsReshape()) {
const auto &reshape_input_lyt =
prev_op->GetOutputLayoutFromSWCByStrategy(prev_op_conf_stra, edge->prev_op_output_index());
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
if (!consistency) {
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
}
} else if (curr_op->name().find(RESHAPEINFO) == std::string::npos) {
const auto &prev_op_conf_stra = configured_ops[prev_op];
const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
if ((prev_op_conf_stra == nullptr) || (!prev_op_conf_stra->IsEqual(prev_op_stra))) {
MS_LOG(ERROR) << "curr_depth: " << curr_depth;
MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
<< "Currently reaching " << prev_op->name() << " and " << curr_op->name() << "."
<< " The full name of these two operators are: " << prev_op->cnode()->fullname_with_scope()
<< " and " << curr_op->cnode()->fullname_with_scope();
} else {
auto consistency = edge->CheckStrategyConsistency(prev_op_conf_stra, curr_op->selected_strategy());
if (!consistency) {
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
}
}
}
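Taken together, these checks soften boundary handling: when propagation reaches an operator whose strategy is already fixed (visited, or user-configured), a mismatch now logs a warning instead of throwing, so propagation continues. A toy sketch of that behavior, propagating an integer "strategy" over a small graph; this is purely illustrative, not the real CostGraph::BFS:

#include <iostream>
#include <map>
#include <queue>
#include <vector>

int main() {
  // Toy graph: node -> successors; a strategy propagates unchanged along edges.
  std::map<int, std::vector<int>> adj = {{0, {1, 2}}, {1, {3}}, {2, {3}}};
  std::map<int, int> chosen = {{0, 7}, {3, 9}};  // user-configured strategies

  std::queue<int> frontier;
  frontier.push(0);
  while (!frontier.empty()) {
    int cur = frontier.front();
    frontier.pop();
    for (int next : adj[cur]) {
      if (chosen.count(next) != 0) {
        // Fixed boundary: check consistency and warn on mismatch
        // (previously an exception aborted the whole propagation).
        if (chosen[next] != chosen[cur]) {
          std::cerr << "warning: inconsistency at edge " << cur << "->" << next << "\n";
        }
        continue;
      }
      chosen[next] = chosen[cur];  // propagate the strategy
      frontier.push(next);
    }
  }
  for (const auto &kv : chosen) {
    std::cout << kv.first << ": " << kv.second << "\n";  // 0:7 1:7 2:7 3:9
  }
  return 0;
}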
@@ -192,7 +181,7 @@ void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
auto curr_depth = next_level.front().second;
visited->at(curr_op) = true;
MS_LOG(INFO) << "curr_depth: " << curr_depth;
if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
if (curr_op->IsReshape()) {
curr_op->set_swc_index(configured_stra_index, curr_depth);
} else {
curr_op->SetSelectedStrategy(configured_stra, curr_depth);
@@ -200,24 +189,23 @@ void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
for (auto &edge : curr_op->succ_edges()) {
const auto &next_op = edge->next_operator();
if (visited->at(next_op)) {
CheckVisitedEdgeConsistency(edge, configured_ops);
CheckVisitedEdgeConsistency(edge);
continue;
}
if ((curr_depth > 0) && (configured_ops.find(next_op) != configured_ops.end())) {
CheckConfiguredSuccEdgeConsistency(edge, configured_ops, curr_depth);
CheckConfiguredSuccEdgeConsistency(edge, configured_ops);
}
if (configured_ops.find(next_op) != configured_ops.end()) {
continue;
}
if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
if (curr_op->IsReshape()) {
auto stra = edge->GetNextOpStrategyByReshapeSWCIndex(curr_op->swc_index());
(void)next_level.emplace(std::make_pair(next_op, std::make_pair(stra, -1)), curr_depth + 1);
} else if (next_op->name().find(RESHAPEINFO) != std::string::npos) {
auto swc_index =
edge->GetReshapeSWCIndexByPrevOpStrategy(curr_op->selected_strategy(), curr_depth, configured_ops);
} else if (next_op->IsReshape()) {
auto swc_index = edge->GetReshapeSWCIndexByPrevOpStrategy(curr_op->selected_strategy());
(void)next_level.emplace(std::make_pair(next_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
} else {
const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithMiniComm(curr_op->selected_strategy());
if (next_op_stra == nullptr) {
PrintStrategy(curr_op->selected_strategy());
MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
@@ -228,24 +216,23 @@ void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
for (auto &edge : curr_op->prev_edges()) {
const auto &prev_op = edge->prev_operator();
if (visited->at(prev_op)) {
CheckVisitedEdgeConsistency(edge, configured_ops);
CheckVisitedEdgeConsistency(edge);
continue;
}
if ((curr_depth > 0) && (configured_ops.find(prev_op) != configured_ops.end())) {
CheckConfiguredPrevEdgeConsistency(edge, configured_ops, curr_depth);
CheckConfiguredPrevEdgeConsistency(edge, configured_ops);
}
if (configured_ops.find(prev_op) != configured_ops.end()) {
continue;
}
if (prev_op->name().find(RESHAPEINFO) != std::string::npos) {
auto swc_index =
edge->GetReshapeSWCIndexByNextOpStrategy(curr_op->selected_strategy(), curr_depth, configured_ops);
if (prev_op->IsReshape()) {
auto swc_index = edge->GetReshapeSWCIndexByNextOpStrategy(curr_op->selected_strategy());
(void)next_level.emplace(std::make_pair(prev_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
} else if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
} else if (curr_op->IsReshape()) {
auto prev_stra = edge->GetPrevOpStrategyByReshapeSWCIndex(curr_op->swc_index());
(void)next_level.emplace(std::make_pair(prev_op, std::make_pair(prev_stra, -1)), curr_depth + 1);
} else {
const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithMiniComm(curr_op->selected_strategy());
if (prev_op_stra == nullptr) {
PrintStrategy(curr_op->selected_strategy());
MS_LOG(EXCEPTION) << prev_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
@@ -1644,7 +1631,7 @@ size_t CostGraph::GetNumEdges() const {
Status CostGraph::InitReshapeStrategy() {
// Reshape init should be applied after the init of its previous node and next node.
for (size_t i = 0; i < ops_.size(); ++i) {
if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
if (ops_[i]->IsReshape()) {
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
auto in_edges = GetOriginalPrevEdges(ops_[i]);
auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
@@ -1699,7 +1686,7 @@ Status CostGraph::InitReshapeStrategy() {
Status CostGraph::InitSelectedStrategy() {
for (auto &op : ops_) {
MS_EXCEPTION_IF_NULL(op);
if (op->name().find(RESHAPEINFO) != std::string::npos) {
if (op->IsReshape()) {
continue;
}
auto result_op = op->InitSelectedStrategy(op->selected_strategy());


@@ -1421,6 +1421,13 @@ StrategyPtr OperatorInfo::GetStrategyFromSWCByOutputLayout(TensorLayout output_l
return nullptr;
}
bool OperatorInfo::IsReshape() {
return name_.find(RESHAPEINFO) != std::string::npos;
}
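The new IsReshape helper centralizes the name().find(RESHAPEINFO) substring test that was previously repeated at every call site touched by this commit. A small self-contained sketch of the pattern, with a simplified class standing in for OperatorInfo:

#include <iostream>
#include <string>
#include <utility>

// Stand-in for the RESHAPEINFO name fragment used in the real code.
const std::string kReshapeInfo = "Reshape";

class Op {
 public:
  explicit Op(std::string name) : name_(std::move(name)) {}
  // Intent-revealing helper replacing scattered name().find(...) checks.
  bool IsReshape() const { return name_.find(kReshapeInfo) != std::string::npos; }

 private:
  std::string name_;
};

int main() {
  Op a("ReshapeInfo42");
  Op b("MatMulInfo7");
  std::cout << a.IsReshape() << " " << b.IsReshape() << "\n";  // 1 0
  return 0;
}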
// Keep at most (1.0 / epsilon) number of available strategies for each operator.
void OperatorInfo::ApproximateStrategies() {
auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();


@@ -150,6 +150,7 @@ class OperatorInfo {
TensorLayout GetOutputLayoutFromSWCByStrategy(StrategyPtr stra, size_t output_index);
StrategyPtr GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index);
StrategyPtr GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index);
bool IsReshape();
void set_swc_index(int64_t, int64_t);
int64_t swc_index() { return swc_index_; }


@@ -523,25 +523,25 @@ int64_t ReshapeInfo::GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &o
return index_computation[0].first;
}
int64_t ReshapeInfo::GetSWCIndexByOutputLayout(const TensorLayout &output_layout) {
std::vector<std::pair<int64_t, double>> index_computation;
int64_t ReshapeInfo::GetSWCIndexByOutputLayoutWithMiniComm(const TensorLayout &output_layout) {
std::vector<std::pair<int64_t, double>> index_comm;
for (size_t i = 0; i < strategy_cost_.size(); ++i) {
const auto &swc = strategy_cost_[i];
if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
(void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
(void)index_comm.emplace_back(SizeToLong(i), swc->cost_list[0]->communication_without_parameter_);
}
}
if (index_computation.empty()) {
if (index_comm.empty()) {
MS_LOG(ERROR) << "There in no available strategy for zero communication cost for reshape: " << name();
return -1;
}
if (index_computation.size() > 1) {
if (index_comm.size() > 1) {
MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
}
std::sort(
index_computation.begin(), index_computation.end(),
index_comm.begin(), index_comm.end(),
[](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
return index_computation[0].first;
return index_comm[0].first;
}
int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &input_layout) {
@@ -566,25 +566,25 @@ int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &in
return index_computation[0].first;
}
int64_t ReshapeInfo::GetSWCIndexByInputLayout(const TensorLayout &input_layout) {
std::vector<std::pair<int64_t, double>> index_computation;
int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithMiniComm(const TensorLayout &input_layout) {
std::vector<std::pair<int64_t, double>> index_comm;
for (size_t i = 0; i < strategy_cost_.size(); ++i) {
const auto &swc = strategy_cost_[i];
if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
(void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
(void)index_comm.emplace_back(SizeToLong(i), swc->cost_list[0]->communication_without_parameter_);
}
}
if (index_computation.empty()) {
if (index_comm.empty()) {
MS_LOG(ERROR) << "There in no available strategy for zero communication cost for reshape: " << name();
return -1;
}
if (index_computation.size() > 1) {
if (index_comm.size() > 1) {
MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
}
std::sort(
index_computation.begin(), index_computation.end(),
index_comm.begin(), index_comm.end(),
[](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
return index_computation[0].first;
return index_comm[0].first;
}
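Note the ranking key in the renamed ...WithMiniComm functions: candidates whose layout matches are ordered by communication_without_parameter_ instead of the computation_cost_ used by the old GetSWCIndexBy*Layout variants. A compact sketch of that selection with a stand-in struct (the field names mirror the diff; everything else is illustrative):

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for one strategy-with-cost (SWC) entry of a Reshape operator.
struct SwcEntry {
  bool layout_matches;                     // produces the desired layout?
  double communication_without_parameter;  // ranking key for ...WithMiniComm
};

int64_t PickSwcIndex(const std::vector<SwcEntry> &entries) {
  int64_t best = -1;
  for (size_t i = 0; i < entries.size(); ++i) {
    if (!entries[i].layout_matches) {
      continue;
    }
    if (best == -1 || entries[i].communication_without_parameter <
                          entries[best].communication_without_parameter) {
      best = static_cast<int64_t>(i);
    }
  }
  return best;  // -1 when no entry matches the desired layout
}

int main() {
  std::vector<SwcEntry> entries = {{true, 3.0}, {false, 0.0}, {true, 1.5}};
  std::cout << PickSwcIndex(entries) << "\n";  // 2: matching entry, least comm
  return 0;
}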
bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, const TensorLayout &output_layout) {
@@ -593,7 +593,13 @@ bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, cons
return false;
}
const auto &swc = strategy_cost_[swc_index];
return swc->outputs_ptr[0].tensor_layout() == output_layout;
if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
return true;
}
MS_LOG(WARNING) << name_ << "'s desired output layout is: " << output_layout.ToString() << ", while the selected "
<< "output layout is: " << swc->outputs_ptr[0].tensor_layout().ToString()
<< " and the input layout is: " << swc->inputs_ptr[0].tensor_layout().ToString();
return false;
}
bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const TensorLayout &input_layout) {
@@ -602,7 +608,13 @@ bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const
return false;
}
const auto &swc = strategy_cost_[swc_index];
return swc->inputs_ptr[0].tensor_layout() == input_layout;
if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
return true;
}
MS_LOG(WARNING) << name_ << "'s desired input layout is:" << input_layout.ToString() << ", while the selected "
<< "input layout is: " << swc->inputs_ptr[0].tensor_layout().ToString()
<< " and the output layout is: " << swc->outputs_ptr[0].tensor_layout().ToString();
return false;
}
TensorLayout ReshapeInfo::GetInputLayoutBySWCIndex(int64_t swc_index) {


@@ -70,9 +70,9 @@ class ReshapeInfo : public OperatorInfo {
int64_t next_operator_index() const { return next_operator_index_; }
int64_t GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &);
int64_t GetSWCIndexByOutputLayout(const TensorLayout &);
int64_t GetSWCIndexByOutputLayoutWithMiniComm(const TensorLayout &);
int64_t GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &);
int64_t GetSWCIndexByInputLayout(const TensorLayout &);
int64_t GetSWCIndexByInputLayoutWithMiniComm(const TensorLayout &);
bool CheckStrategyConsistencyByOutputLayout(int64_t, const TensorLayout &);
bool CheckStrategyConsistencyByInputLayout(int64_t, const TensorLayout &);


@@ -647,7 +647,7 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
if (ParallelContext::GetInstance()->sharding_propagation() && (prev_prim->name() == CAST) &&
(configured_stra_ops_.find(node_op_info) != configured_stra_ops_.end())) {
const auto next_op_stra = configured_stra_ops_[node_op_info];
const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithZeroComm(next_op_stra);
const auto cast_stra = edge_ptr->GetPrevOpStrategyByNextOpStrategyWithMiniComm(next_op_stra);
if (cast_stra == nullptr) {
MS_LOG(EXCEPTION) << "No available strategy for: " << prev_op_info->name();
}
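With this change, a Cast that feeds an operator carrying a user-configured strategy derives its own strategy through the minimum-communication lookup instead of requiring a zero-communication match. A schematic sketch with hypothetical names, using nullopt where the real code raises an exception:

#include <iostream>
#include <map>
#include <optional>
#include <utility>

// (cast strategy id, next-op strategy id) -> communication cost of the pair.
using EdgeCosts = std::map<std::pair<int, int>, double>;

// Pick the cast strategy with minimum communication cost toward the
// configured next-op strategy; nullopt mirrors the 'no strategy' EXCEPTION.
std::optional<int> CastStrategyFor(const EdgeCosts &costs, int next_stra) {
  std::optional<int> best;
  double best_cost = 0.0;
  for (const auto &entry : costs) {
    if (entry.first.second != next_stra) {
      continue;
    }
    if (!best.has_value() || entry.second < best_cost) {
      best = entry.first.first;
      best_cost = entry.second;
    }
  }
  return best;
}

int main() {
  EdgeCosts costs = {{{0, 5}, 2.0}, {{1, 5}, 0.0}, {{2, 6}, 0.0}};
  if (auto s = CastStrategyFor(costs, 5)) {
    std::cout << "cast strategy " << *s << "\n";  // 1: zero-comm toward 5
  }
  return 0;
}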


@@ -13,7 +13,6 @@
# limitations under the License.
import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
@@ -83,5 +82,4 @@ def test_auto_parallel_activation4():
strategy2 = None
strategy3 = ((8, 2),)
net = Net(_w1, strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
compile_net(net)


@@ -13,7 +13,6 @@
# limitations under the License.
import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
@@ -59,8 +58,7 @@ def test_auto_parallel_activation1():
strategy2 = ((8, 1),)
strategy3 = ((1, 8), (1, 1))
net = Net(_w1, strategy1, strategy2, strategy3)
with pytest.raises(RuntimeError):
compile_net(net)
compile_net(net)
def test_auto_parallel_activation2():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0,