Support Reshape in sharding propagation:

1) use the SWC index into 'strategy_cost_' as Reshape's selected strategy;
2) when encountering a Reshape in BFS, select the SWC index with zero communication cost (a minimal sketch of this selection rule follows this list);
3) when encountering a Reshape that has already been visited, check whether communication occurs between the Reshape and the current operator; communication is acceptable when it happens between two configured operators;
4) two consecutive Reshapes are currently not supported;
5) adjust the BFS structure in graph_costmodel.cc;
6) refactor some code in step_auto_parallel.cc to reduce cyclomatic complexity.
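Below is a minimal, self-contained C++ sketch of the selection rule described in items 2) and 3): given the tensor layout the Reshape must match, prefer the strategy-with-cost (SWC) entry with zero communication cost, and accept a communicating entry only when both neighboring operators already carry user-configured strategies. The types and names (SwcEntry, FindCheapestSwc, SelectReshapeSwcIndex) are illustrative stand-ins, not the actual ReshapeInfo/Edge API.

// Illustrative sketch only: simplified stand-ins for strategy_cost_ selection.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct SwcEntry {              // stand-in for one StrategyWithCost element
  std::string output_layout;   // stand-in for a TensorLayout
  double communication_cost;   // analogue of communication_without_parameter_
  double computation_cost;     // analogue of computation_cost_
};

// Return the index of the cheapest entry matching the required layout,
// optionally restricted to entries with zero communication cost; -1 if none.
int64_t FindCheapestSwc(const std::vector<SwcEntry> &swcs, const std::string &required_layout,
                        bool require_zero_comm) {
  int64_t best = -1;
  for (size_t i = 0; i < swcs.size(); ++i) {
    const auto &swc = swcs[i];
    if (swc.output_layout != required_layout) continue;
    if (require_zero_comm && swc.communication_cost != 0.0) continue;
    if (best < 0 || swc.computation_cost < swcs[best].computation_cost) best = static_cast<int64_t>(i);
  }
  return best;
}

// Two-step rule: try zero-communication entries first; only when the Reshape
// sits between two user-configured operators accept a communicating entry.
int64_t SelectReshapeSwcIndex(const std::vector<SwcEntry> &swcs, const std::string &required_layout,
                              bool both_neighbors_configured) {
  int64_t idx = FindCheapestSwc(swcs, required_layout, /*require_zero_comm=*/true);
  if (idx == -1 && both_neighbors_configured) {
    idx = FindCheapestSwc(swcs, required_layout, /*require_zero_comm=*/false);
  }
  return idx;
}

int main() {
  std::vector<SwcEntry> swcs = {{"[8,1]", 2.0, 1.0}, {"[4,2]", 0.0, 3.0}, {"[4,2]", 0.0, 2.0}};
  std::cout << SelectReshapeSwcIndex(swcs, "[4,2]", false) << "\n";  // 2: cheapest zero-comm match
  std::cout << SelectReshapeSwcIndex(swcs, "[8,1]", false) << "\n";  // -1: only a communicating match exists
  std::cout << SelectReshapeSwcIndex(swcs, "[8,1]", true) << "\n";   // 0: accepted between configured operators
  return 0;
}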
Author: Xiaoda Zhang, 2021-11-09 15:39:02 +08:00
parent 5ada60b949
commit a772767265
9 changed files with 792 additions and 54 deletions


@@ -23,6 +23,7 @@
#include "frontend/parallel/auto_parallel/costmodel.h"
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "frontend/parallel/ops_info/reshape_info.h"
namespace mindspore {
namespace parallel {
@@ -394,6 +395,122 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPt
return prev_op_stras[0].first;
}
int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra, int64_t curr_depth,
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
if (prev_op_->name().find(RESHAPEINFO) == std::string::npos) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s prev_op is not a Reshape.";
}
if (next_op_->name().find(RESHAPEINFO) != std::string::npos) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
}
const auto &reshape_output_layout = next_op_->GetInputLayoutFromSWCByStrategy(next_op_stra, next_op_input_index_);
MS_LOG(INFO) << prev_op_->name() << "'s output layout: " << reshape_output_layout.ToString();
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
// First, try to find the zero communication strategy.
auto swc_index = reshape_ptr->GetSWCIndexByOutputLayoutWithZeroComm(reshape_output_layout);
if (swc_index == -1 && curr_depth == 0) {
const auto &prev_edges = reshape_ptr->prev_edges();
if (!prev_edges.empty()) {
auto prev_edge = prev_edges[0];
if (configured_ops.find(prev_edge->prev_operator()) != configured_ops.end()) {
// Here, Reshape's previous and next operators both have configured strategies,
// so it is acceptable that communication happens here.
swc_index = reshape_ptr->GetSWCIndexByOutputLayout(reshape_output_layout);
}
}
}
if (swc_index == -1) {
MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
}
return swc_index;
}
int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra, int64_t curr_depth,
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
if (next_op_->name().find(RESHAPEINFO) == std::string::npos) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
}
if (prev_op_->name().find(RESHAPEINFO) != std::string::npos) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
}
const auto &reshape_input_lyt = prev_op_->GetOutputLayoutFromSWCByStrategy(prev_op_stra, prev_op_output_index_);
MS_LOG(INFO) << next_op_->name() << "'s input layout: " << reshape_input_lyt.ToString();
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
// First, try to find the zero communication strategy.
auto swc_index = reshape_ptr->GetSWCIndexByInputLayoutWithZeroComm(reshape_input_lyt);
if (swc_index == -1 && curr_depth == 0) {
const auto &next_edges = reshape_ptr->succ_edges();
if (!next_edges.empty()) {
auto next_edge = next_edges[0];
if (configured_ops.find(next_edge->next_operator()) != configured_ops.end()) {
// Here, Reshape's previous and next operators both have configured strategies,
// so it is acceptable that communication happens here.
swc_index = reshape_ptr->GetSWCIndexByInputLayout(reshape_input_lyt);
}
}
}
if (swc_index == -1) {
MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << next_op_->name();
}
return swc_index;
}
StrategyPtr Edge::GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index) {
if (next_op_->name().find(RESHAPEINFO) == std::string::npos) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
}
if (prev_op_->name().find(RESHAPEINFO) != std::string::npos) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
}
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
const auto &reshape_input_lyt = reshape_ptr->GetInputLayoutBySWCIndex(swc_index);
auto stra = prev_op_->GetStrategyFromSWCByOutputLayout(reshape_input_lyt, prev_op_output_index_);
if (stra == nullptr) {
MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
}
return stra;
}
StrategyPtr Edge::GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index) {
if (prev_op_->name().find(RESHAPEINFO) == std::string::npos) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
}
if (next_op_->name().find(RESHAPEINFO) != std::string::npos) {
MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
}
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
const auto &reshape_output_lyt = reshape_ptr->GetOutputLayoutBySWCIndex(swc_index);
auto stra = next_op_->GetStrategyFromSWCByInputLayout(reshape_output_lyt, next_op_input_index_);
if (stra == nullptr) {
MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
}
return stra;
}
bool Edge::CheckStrategyConsistency(const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
auto prev_stra = prev_op_->selected_strategy();
auto next_stra = next_op_->selected_strategy();
if (prev_stra == nullptr) {
MS_LOG(EXCEPTION) << prev_op_->name() << "'s selected strategy is null!";
}
if (next_stra == nullptr) {
MS_LOG(EXCEPTION) << next_op_->name() << "'s selected strategy is null!";
}
auto cost = GetCostByStrategyPair({prev_stra, next_stra});
if ((configured_ops.find(prev_op_) == configured_ops.end() ||
configured_ops.find(next_op_) == configured_ops.end()) &&
(cost == nullptr || cost->communication_cost_ > 0.0)) {
PrintStrategy(prev_op_->selected_strategy());
PrintStrategy(next_op_->selected_strategy());
MS_LOG(ERROR) << "There are redistribution cost occurs at edge: " << edge_name()
<< ", consider configuring sharding strategies for two operators."
<< " The full name of these two operators are: " << prev_op_->cnode()->fullname_with_scope()
<< " and " << next_op_->cnode()->fullname_with_scope();
return false;
}
return true;
}
void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map) {
cost_map_ = cost_map;
pre_op_output_.clear();


@@ -82,8 +82,17 @@ class Edge {
Status InitEdgeCost();
std::map<CostPtrKey, CostPtrList> GetCostMap() { return cost_map_; }
CostPtr GetCostByStrategyPair(const CostPtrKey &);
StrategyPtr GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPtr &);
StrategyPtr GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPtr &);
int64_t GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra, int64_t curr_depth,
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
int64_t GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra, int64_t curr_depth,
const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
StrategyPtr GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index);
StrategyPtr GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index);
bool CheckStrategyConsistency(const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
void SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &);
// For two operators u--->v, given the output tensor layout of u,
// and the input tensor layout of v, return the redistribution cost,
@@ -166,7 +175,7 @@ class Edge {
// operators. The resulting edges are different from those normal edges, thus this Bool variable distinguishes them.
// If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor.
bool is_identity_edge;
CostPtr selected_cost_;
CostPtr selected_cost_ = nullptr;
// In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator
// is parameter-involved
int64_t is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved


@@ -101,84 +101,157 @@ void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &
}
}
void CheckShardingConsisitency(std::map<OperatorInfoPtr, StrategyPtr> configured_ops, const OperatorInfoPtr &curr_op,
const OperatorInfoPtr &another_op, const CostPtr &cost, const EdgePtr &edge) {
if ((configured_ops.find(another_op) == configured_ops.end()) &&
(cost == nullptr || cost->communication_cost_ != 0.0)) {
PrintStrategy(another_op->selected_strategy());
PrintStrategy(curr_op->selected_strategy());
MS_LOG(EXCEPTION) << "There are redistribution cost occurs at edge: " << edge->edge_name()
<< ", consider configuring sharding strategies for two operators."
<< " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
<< " and " << another_op->cnode()->fullname_with_scope();
void CheckVisitedEdgeConsistency(const EdgePtr &edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
auto prev_op = edge->prev_operator();
auto next_op = edge->next_operator();
if (prev_op->name().find(RESHAPEINFO) != std::string::npos) {
const auto &reshape_output_lyt =
next_op->GetInputLayoutFromSWCByStrategy(next_op->selected_strategy(), edge->next_op_input_index());
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op);
auto consistency =
reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
if (!consistency) {
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
}
} else if (next_op->name().find(RESHAPEINFO) != std::string::npos) {
const auto &reshape_input_lyt =
prev_op->GetOutputLayoutFromSWCByStrategy(prev_op->selected_strategy(), edge->prev_op_output_index());
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op);
auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
if (!consistency) {
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
}
} else {
auto consistency = edge->CheckStrategyConsistency(configured_ops);
if (!consistency) {
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
}
}
}
void CheckConfiguredSuccEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops,
int64_t curr_depth) {
auto curr_op = edge->prev_operator();
auto next_op = edge->next_operator();
if ((curr_op->name().find(RESHAPEINFO) != std::string::npos) && curr_depth > 1) {
auto next_op_conf_stra = configured_ops[next_op];
const auto &reshape_output_lyt =
next_op->GetInputLayoutFromSWCByStrategy(next_op_conf_stra, edge->next_op_input_index());
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
auto consistency =
reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
if (!consistency) {
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
}
} else if (curr_op->name().find(RESHAPEINFO) == std::string::npos) {
const auto &next_op_conf_stra = configured_ops[next_op];
const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
if ((next_op_conf_stra == nullptr) || (!next_op_conf_stra->IsEqual(next_op_stra))) {
MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
<< "Currently reaching " << curr_op->name() << " and " << next_op->name() << "."
<< " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
<< " and " << next_op->cnode()->fullname_with_scope();
}
}
}
void CheckConfiguredPrevEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops,
int64_t curr_depth) {
auto curr_op = edge->next_operator();
auto prev_op = edge->prev_operator();
if (curr_op->name().find(RESHAPEINFO) != std::string::npos && curr_depth > 1) {
auto prev_op_conf_stra = configured_ops[prev_op];
const auto &reshape_input_lyt =
prev_op->GetOutputLayoutFromSWCByStrategy(prev_op_conf_stra, edge->prev_op_output_index());
auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
if (!consistency) {
MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
}
} else if (curr_op->name().find(RESHAPEINFO) == std::string::npos) {
const auto &prev_op_conf_stra = configured_ops[prev_op];
const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
if ((prev_op_conf_stra == nullptr) || (!prev_op_conf_stra->IsEqual(prev_op_stra))) {
MS_LOG(ERROR) << "curr_depth: " << curr_depth;
MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
<< "Currently reaching " << prev_op->name() << " and " << curr_op->name() << "."
<< " The full name of these two operators are: " << prev_op->cnode()->fullname_with_scope()
<< " and " << curr_op->cnode()->fullname_with_scope();
}
}
}
void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
std::map<OperatorInfoPtr, StrategyPtr> configured_ops, std::map<OperatorInfoPtr, bool> *visited) {
std::queue<std::pair<std::pair<OperatorInfoPtr, StrategyPtr>, size_t>> next_level;
(void)next_level.emplace(std::make_pair(op, op_stra), 0);
std::queue<std::pair<std::pair<OperatorInfoPtr, std::pair<StrategyPtr, int64_t>>, int64_t>> next_level;
(void)next_level.emplace(std::make_pair(op, std::make_pair(op_stra, -1)), 0);
while (!next_level.empty()) {
auto curr_op = next_level.front().first.first;
auto configured_stra = next_level.front().first.second;
auto configured_stra = next_level.front().first.second.first;
auto configured_stra_index = next_level.front().first.second.second;
auto curr_depth = next_level.front().second;
visited->at(curr_op) = true;
MS_LOG(INFO) << "curr_depth: " << curr_depth;
curr_op->SetSelectedStrategy(configured_stra, curr_depth);
if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
curr_op->set_swc_index(configured_stra_index, curr_depth);
} else {
curr_op->SetSelectedStrategy(configured_stra, curr_depth);
}
for (auto &edge : curr_op->succ_edges()) {
const auto &next_op = edge->next_operator();
if (visited->at(next_op)) {
const auto cost = edge->GetCostByStrategyPair({curr_op->selected_strategy(), next_op->selected_strategy()});
CheckShardingConsisitency(configured_ops, curr_op, next_op, cost, edge);
CheckVisitedEdgeConsistency(edge, configured_ops);
continue;
}
if ((curr_depth > 0) && (configured_ops.find(next_op) != configured_ops.end())) {
const auto &next_op_conf_stra = configured_ops[next_op];
const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
if ((next_op_conf_stra == nullptr) || (!next_op_conf_stra->IsEqual(next_op_stra))) {
MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
<< "Currently reaching " << curr_op->name() << " and " << next_op->name() << "."
<< " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
<< " and " << next_op->cnode()->fullname_with_scope();
}
CheckConfiguredSuccEdgeConsistency(edge, configured_ops, curr_depth);
}
if (configured_ops.find(next_op) != configured_ops.end()) {
continue;
}
const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
if (next_op_stra == nullptr) {
PrintStrategy(curr_op->selected_strategy());
MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
auto stra = edge->GetNextOpStrategyByReshapeSWCIndex(curr_op->swc_index());
(void)next_level.emplace(std::make_pair(next_op, std::make_pair(stra, -1)), curr_depth + 1);
} else if (next_op->name().find(RESHAPEINFO) != std::string::npos) {
auto swc_index =
edge->GetReshapeSWCIndexByPrevOpStrategy(curr_op->selected_strategy(), curr_depth, configured_ops);
(void)next_level.emplace(std::make_pair(next_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
} else {
const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
if (next_op_stra == nullptr) {
PrintStrategy(curr_op->selected_strategy());
MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
}
(void)next_level.emplace(std::make_pair(next_op, std::make_pair(next_op_stra, -1)), curr_depth + 1);
}
(void)next_level.emplace(std::make_pair(next_op, next_op_stra), curr_depth + 1);
}
for (auto &edge : curr_op->prev_edges()) {
const auto &prev_op = edge->prev_operator();
if (visited->at(prev_op)) {
const auto cost = edge->GetCostByStrategyPair({prev_op->selected_strategy(), curr_op->selected_strategy()});
CheckShardingConsisitency(configured_ops, curr_op, prev_op, cost, edge);
CheckVisitedEdgeConsistency(edge, configured_ops);
continue;
}
if ((curr_depth > 0) && (configured_ops.find(prev_op) != configured_ops.end())) {
const auto &prev_op_conf_stra = configured_ops[prev_op];
const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
if ((prev_op_conf_stra == nullptr) || (!prev_op_conf_stra->IsEqual(prev_op_stra))) {
MS_LOG(ERROR) << "curr_depth: " << curr_depth;
MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
<< "Currently reaching " << prev_op->name() << " and " << curr_op->name() << "."
<< " The full name of these two operators are: " << prev_op->cnode()->fullname_with_scope()
<< " and " << curr_op->cnode()->fullname_with_scope();
}
CheckConfiguredPrevEdgeConsistency(edge, configured_ops, curr_depth);
}
if (configured_ops.find(prev_op) != configured_ops.end()) {
continue;
}
const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
if (prev_op_stra == nullptr) {
PrintStrategy(curr_op->selected_strategy());
MS_LOG(EXCEPTION) << prev_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
if (prev_op->name().find(RESHAPEINFO) != std::string::npos) {
auto swc_index =
edge->GetReshapeSWCIndexByNextOpStrategy(curr_op->selected_strategy(), curr_depth, configured_ops);
(void)next_level.emplace(std::make_pair(prev_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
} else if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
auto prev_stra = edge->GetPrevOpStrategyByReshapeSWCIndex(curr_op->swc_index());
(void)next_level.emplace(std::make_pair(prev_op, std::make_pair(prev_stra, -1)), curr_depth + 1);
} else {
const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
if (prev_op_stra == nullptr) {
PrintStrategy(curr_op->selected_strategy());
MS_LOG(EXCEPTION) << prev_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
}
(void)next_level.emplace(std::make_pair(prev_op, std::make_pair(prev_op_stra, -1)), curr_depth + 1);
}
(void)next_level.emplace(std::make_pair(prev_op, prev_op_stra), curr_depth + 1);
}
next_level.pop();
}


@@ -1363,6 +1363,50 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
return SUCCESS;
}
TensorLayout OperatorInfo::GetInputLayoutFromSWCByStrategy(StrategyPtr stra, size_t input_index) {
auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) { return swc->strategy_ptr->IsEqual(stra); };
auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
if (it != strategy_cost_.end()) {
const auto &input_info = (*it)->inputs_ptr[input_index];
return std::move(input_info.tensor_layout());
}
TensorLayout empty;
return empty;
}
TensorLayout OperatorInfo::GetOutputLayoutFromSWCByStrategy(StrategyPtr stra, size_t output_index) {
auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) { return swc->strategy_ptr->IsEqual(stra); };
auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
if (it != strategy_cost_.end()) {
const auto &output_info = (*it)->outputs_ptr[output_index];
return std::move(output_info.tensor_layout());
}
TensorLayout empty;
return empty;
}
StrategyPtr OperatorInfo::GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index) {
auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) {
return swc->inputs_ptr[input_index].tensor_layout() == input_layout;
};
auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
if (it != strategy_cost_.end()) {
return (*it)->strategy_ptr;
}
return nullptr;
}
StrategyPtr OperatorInfo::GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index) {
auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) {
return swc->outputs_ptr[output_index].tensor_layout() == output_layout;
};
auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
if (it != strategy_cost_.end()) {
return (*it)->strategy_ptr;
}
return nullptr;
}
// Keep at most (1.0 / epsilon) number of available strategies for each operator.
void OperatorInfo::ApproximateStrategies() {
auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();
@@ -1683,6 +1727,12 @@ void OperatorInfo::SetSelectedStrategy(const StrategyPtr &s_strategy, size_t cur
selected_strategy_depth_ = SizeToLong(curr_depth);
}
void OperatorInfo::set_swc_index(int64_t swc, int64_t depth) {
MS_LOG(INFO) << "Set SWC index: " << swc << " for: " << name();
selected_strategy_depth_ = depth;
swc_index_ = swc;
}
CNodePtr OperatorInfo::cnode() {
MS_EXCEPTION_IF_NULL(cnode_);
return cnode_;


@@ -144,6 +144,14 @@ class OperatorInfo {
void SetSelectedStrategy(const StrategyPtr &s_strategy, size_t);
StrategyPtr selected_strategy() const { return selected_strategy_; }
CostPtr selected_cost() const { return selected_cost_; }
TensorLayout GetInputLayoutFromSWCByStrategy(StrategyPtr stra, size_t input_index);
TensorLayout GetOutputLayoutFromSWCByStrategy(StrategyPtr stra, size_t output_index);
StrategyPtr GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index);
StrategyPtr GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index);
void set_swc_index(int64_t, int64_t);
int64_t swc_index() { return swc_index_; }
// Approximate the list of available strategies
void ApproximateStrategies();
// Make the list of available strategies exact and re-init the related edges incident to this operator
@@ -293,6 +301,7 @@ class OperatorInfo {
private:
OperatorCostPtr operator_cost_;
std::vector<TypePtr> outputs_type_;
int64_t swc_index_ = -1;
};
Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);


@@ -431,7 +431,7 @@ std::vector<StrategyPtr> ReshapeInfo::GenerateOpStrategies(int64_t) {
return sp_vector;
}
Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
Status ReshapeInfo::GenerateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs,
int64_t out_index, int64_t in_index, bool is_prev_param,
bool is_next_reshape) {
@@ -488,7 +488,137 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
if (strategy_cost_.empty()) {
return FAILED;
}
MS_LOG(INFO) << "Print " << name() << "'s 'strategy_cost':";
for (auto &swc : strategy_cost_) {
MS_LOG(INFO) << name() << "'s strategy:";
PrintStrategy(swc->strategy_ptr);
MS_LOG(INFO) << "The corresponding cost: " << swc->cost_list[0]->computation_cost_ << ", "
<< swc->cost_list[0]->communication_cost_ << ", "
<< swc->cost_list[0]->communication_without_parameter_;
MS_LOG(INFO) << "Input layout: " << swc->inputs_ptr[0].tensor_layout().ToString();
MS_LOG(INFO) << "Output layout: " << swc->outputs_ptr[0].tensor_layout().ToString();
}
return SUCCESS;
}
int64_t ReshapeInfo::GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &output_layout) {
std::vector<std::pair<int64_t, double>> index_computation;
for (size_t i = 0; i < strategy_cost_.size(); ++i) {
const auto &swc = strategy_cost_[i];
if (swc->outputs_ptr[0].tensor_layout() == output_layout &&
swc->cost_list[0]->communication_without_parameter_ == 0.0) {
(void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
}
}
if (index_computation.empty()) {
MS_LOG(ERROR) << "There in no available strategy for zero communication cost for reshape: " << name();
return -1;
}
if (index_computation.size() > 1) {
MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
}
std::sort(
index_computation.begin(), index_computation.end(),
[](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
return index_computation[0].first;
}
int64_t ReshapeInfo::GetSWCIndexByOutputLayout(const TensorLayout &output_layout) {
std::vector<std::pair<int64_t, double>> index_computation;
for (size_t i = 0; i < strategy_cost_.size(); ++i) {
const auto &swc = strategy_cost_[i];
if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
(void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
}
}
if (index_computation.empty()) {
MS_LOG(ERROR) << "There in no available strategy for zero communication cost for reshape: " << name();
return -1;
}
if (index_computation.size() > 1) {
MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
}
std::sort(
index_computation.begin(), index_computation.end(),
[](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
return index_computation[0].first;
}
int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &input_layout) {
std::vector<std::pair<int64_t, double>> index_computation;
for (size_t i = 0; i < strategy_cost_.size(); ++i) {
const auto &swc = strategy_cost_[i];
if (swc->inputs_ptr[0].tensor_layout() == input_layout &&
swc->cost_list[0]->communication_without_parameter_ == 0.0) {
(void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
}
}
if (index_computation.empty()) {
MS_LOG(ERROR) << "There in no available strategy for zero communication cost for reshape: " << name();
return -1;
}
if (index_computation.size() > 1) {
MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
}
std::sort(
index_computation.begin(), index_computation.end(),
[](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
return index_computation[0].first;
}
int64_t ReshapeInfo::GetSWCIndexByInputLayout(const TensorLayout &input_layout) {
std::vector<std::pair<int64_t, double>> index_computation;
for (size_t i = 0; i < strategy_cost_.size(); ++i) {
const auto &swc = strategy_cost_[i];
if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
(void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
}
}
if (index_computation.empty()) {
MS_LOG(ERROR) << "There in no available strategy for zero communication cost for reshape: " << name();
return -1;
}
if (index_computation.size() > 1) {
MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
}
std::sort(
index_computation.begin(), index_computation.end(),
[](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
return index_computation[0].first;
}
bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, const TensorLayout &output_layout) {
if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range.";
return false;
}
const auto &swc = strategy_cost_[swc_index];
return swc->outputs_ptr[0].tensor_layout() == output_layout;
}
bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const TensorLayout &input_layout) {
if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range.";
return false;
}
const auto &swc = strategy_cost_[swc_index];
return swc->inputs_ptr[0].tensor_layout() == input_layout;
}
TensorLayout ReshapeInfo::GetInputLayoutBySWCIndex(int64_t swc_index) {
if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range.";
}
const auto &swc = strategy_cost_[swc_index];
return std::move(swc->inputs_ptr[0].tensor_layout());
}
TensorLayout ReshapeInfo::GetOutputLayoutBySWCIndex(int64_t swc_index) {
if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range.";
}
const auto &swc = strategy_cost_[swc_index];
return std::move(swc->outputs_ptr[0].tensor_layout());
}
} // namespace parallel
} // namespace mindspore


@@ -58,7 +58,7 @@ class ReshapeInfo : public OperatorInfo {
void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; }
void set_pre_operator_index(int64_t pre_index) { pre_operator_index_ = pre_index; }
void set_next_operator_index(int64_t next_index) { next_operator_index_ = next_index; }
Status GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
Status GenerateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, int64_t out_index,
int64_t in_index, bool is_prev_param, bool is_next_reshape);
Status GenerateStrategies(int64_t stage_id) override;
@@ -69,6 +69,16 @@ class ReshapeInfo : public OperatorInfo {
int64_t pre_operator_index() const { return pre_operator_index_; }
int64_t next_operator_index() const { return next_operator_index_; }
int64_t GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &);
int64_t GetSWCIndexByOutputLayout(const TensorLayout &);
int64_t GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &);
int64_t GetSWCIndexByInputLayout(const TensorLayout &);
bool CheckStrategyConsistencyByOutputLayout(int64_t, const TensorLayout &);
bool CheckStrategyConsistencyByInputLayout(int64_t, const TensorLayout &);
TensorLayout GetInputLayoutBySWCIndex(int64_t);
TensorLayout GetOutputLayoutBySWCIndex(int64_t);
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;


@@ -665,6 +665,14 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
(*edge_count)++;
}
void CheckAndApplyApproximation() {
// If 'approximation' is enabled, check that the edges have effective costs.
auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
if (approximation) {
entire_costgraph->CheckApproximateCostGraphEdges();
}
}
void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
// Step 2
MS_LOG(INFO) << "Constructing edges for cost graph begins.";
@@ -736,11 +744,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
}
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
}
// If 'approximation' is enabled, the edges need to be checked have effective costs.
auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
if (approximation) {
entire_costgraph->CheckApproximateCostGraphEdges();
}
CheckAndApplyApproximation();
MS_LOG(INFO) << "Constructing edges for cost graph ends.";
}
@@ -913,7 +917,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
reshape_info->set_next_operator_index(in_index);
}
bool is_prev_param = pre_node->isa<Parameter>();
if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param,
if (reshape_info->GenerateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param,
is_next_reshape) != SUCCESS) {
MS_LOG(EXCEPTION) << "reshape generate strategy_costs failed!";
}


@@ -0,0 +1,336 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.common import Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from tests.ut.python.ops.test_math_ops import VirtualLoss
grad_all = C.GradOperation(get_all=True)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x):
predict = self.network(x)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x):
return grad_all(self.network)(x)
class NetWithLossTwoInput(nn.Cell):
def __init__(self, network):
super(NetWithLossTwoInput, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
class GradWrapTwoInput(nn.Cell):
def __init__(self, network):
super(GradWrapTwoInput, self).__init__()
self.network = network
def construct(self, x, y):
return grad_all(self.network)(x, y)
def compile_graph(net, device_num, x):
context.set_auto_parallel_context(device_num=device_num, global_rank=0, parallel_mode="auto_parallel",
sharding_propagation=True)
net.set_auto_parallel()
net.set_train()
_cell_graph_executor.compile(net, x)
def compile_graph_two_input(net, device_num, x, y):
context.set_auto_parallel_context(device_num=device_num, global_rank=0, parallel_mode="auto_parallel",
sharding_propagation=True)
net.set_auto_parallel()
net.set_train()
_cell_graph_executor.compile(net, x, y)
def test_reshape_reshape():
"""
Feature: Sharding propagation for Reshape.
Description: ReLU->Reshape->Reshape
Expectation: compile done without error.
"""
device_num = 8
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.reshape = P.Reshape()
self.relu = P.ReLU().shard(((1, 1, 1, 1),))
def construct(self, x):
x = self.relu(x)
out = self.reshape(x, (64, 28))
out = self.reshape(out, (64, 28, 1))
return out
x = Tensor(np.ones([device_num * 8, 28, 1, 1]), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net()))
compile_graph(net, device_num, x)
def test_reshape_auto_1():
"""
Feature: Sharding propagation for Reshape.
Description: ReLU->Reshape->MatMul
Expectation: compile done without error.
"""
device_num = 8
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.relu = P.ReLU().shard(((1, 1, 1, 1),))
self.reshape = P.Reshape()
self.matmul = P.MatMul().shard(((2, 1), (1, 4)))
self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
def construct(self, x):
x = self.relu(x)
out = self.reshape(x, (64, 28))
out = self.matmul(out, self.matmul_weight)
return out
x = Tensor(np.ones([device_num * 8, 28, 1, 1]), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net()))
compile_graph(net, device_num, x)
def test_reshape_auto_2():
"""
Feature: Sharding propagation for Reshape.
Description: ReLU->Reshape->MatMul->Reshape->Add
Expectation: compile done without error.
"""
device_num = 8
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.relu = P.ReLU()
self.relu2 = P.ReLU()
self.reshape = P.Reshape()
self.matmul = P.MatMul().shard(((2, 1), (1, 4)))
self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
self.add = P.Add().shard(((2, 4), (2, 4)))
self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1")
def construct(self, x):
out = self.relu(x)
out = self.relu2(out)
out = self.reshape(out, (64, 28))
out = self.matmul(out, self.matmul_weight)
out = self.reshape(out, (128, 32))
out = self.add(out, self.add_weight)
return out
x = Tensor(np.ones([device_num * 8, 28, 1, 1]), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net()))
compile_graph(net, device_num, x)
def test_reshape_auto_3():
"""
Feature: Sharding propagation for Reshape.
Description: Mul->Add->Cast->Reshape->Cast->ReduceMean
Expectation: compile done without error.
"""
device_num = 8
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.gamma = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="gamma")
self.beta = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="beta")
self.add = P.TensorAdd().shard(((8, 1, 1), (1,)))
self.mul = P.Mul().shard(((8, 1, 1), (1,)))
self.mean = P.ReduceMean(keep_dims=True).shard(((8, 1),))
self.reshape = P.Reshape()
self.dtype1 = mstype.float16
self.dtype2 = mstype.float32
def construct(self, x):
out = self.add(self.mul(x, self.gamma), self.beta)
out = F.cast(out, self.dtype1)
out = self.reshape(out, (-1, 1024))
out = F.cast(out, self.dtype2)
out = self.mean(out, -1)
return out
x = Tensor(np.ones([2048, 30, 1024]), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net()))
compile_graph(net, device_num, x)
def test_reshape_auto_4():
"""
Feature: Sharding propagation for Reshape.
Description: Mul->Add->Cast->Reshape->Cast->ReduceMean
Expectation: compile done without error.
"""
device_num = 8
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.gamma = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="gamma")
self.beta = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="beta")
self.add = P.TensorAdd().shard(((8, 1, 1), (1,)))
self.mul = P.Mul().shard(((8, 1, 1), (1,)))
self.mean = P.ReduceMean(keep_dims=True)
self.reshape = P.Reshape()
self.dtype1 = mstype.float16
self.dtype2 = mstype.float32
def construct(self, x):
out = self.add(self.mul(x, self.gamma), self.beta)
out = F.cast(out, self.dtype1)
out = self.reshape(out, (-1, 1024))
out = F.cast(out, self.dtype2)
out = self.mean(out, -1)
return out
x = Tensor(np.ones([2048, 30, 1024]), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net()))
compile_graph(net, device_num, x)
def test_reshape_auto_5():
"""
Feature: Sharding propagation for Reshape.
Description: Mul->Add->Reshape->ReduceMean
Expectation: compile done without error.
"""
device_num = 8
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.gamma = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="gamma")
self.beta = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="beta")
self.add = P.TensorAdd().shard(((8, 1, 1), (1,)))
self.mul = P.Mul()
self.mean = P.ReduceMean(keep_dims=True).shard(((2, 4),))
self.reshape = P.Reshape()
self.dtype1 = mstype.float16
self.dtype2 = mstype.float32
def construct(self, x):
out = self.add(self.mul(x, self.gamma), self.beta)
out = self.reshape(out, (-1, 1024))
out = self.mean(out, -1)
return out
x = Tensor(np.ones([2048, 30, 1024]), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net()))
compile_graph(net, device_num, x)
def test_reshape_auto_6():
"""
Feature: Sharding propagation for Reshape.
Description: Reshape->ReLU->Mul->Reshape->Add->Mul->Reshape->Add
Expectation: compile done without error.
"""
device_num = 8
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.relu = P.ReLU()
self.mul = P.Mul().shard(((8, 1, 1), (8, 1, 1)))
self.reshape = P.Reshape()
self.reduce_sum = P.ReduceSum()
self.wide_w = Parameter(Tensor(np.ones([8, 1024*8, 64]), dtype=ms.float32), name="weight")
def construct(self, x, y):
mask = self.reshape(y, (8, 1024*8, 1))
w_id = self.relu(x)
wx = self.mul(w_id, mask)
wide_out = self.reshape(self.reduce_sum(wx, 1), (-1, 1))
deep_id = x + self.wide_w
vx = self.mul(deep_id, mask)
deep_in = self.reshape(vx, (-1, 1024*8*64))
out = wide_out + deep_in
return out
x = Tensor(np.ones([8, 1024*device_num, 1]), dtype=ms.float32)
y = Tensor(np.ones([8, 1024*device_num]), dtype=ms.float32)
net = GradWrapTwoInput(NetWithLossTwoInput(Net()))
compile_graph_two_input(net, device_num, x, y)
def test_reshape_depend_reshape():
"""
Feature: Sharding propagation for Reshape.
Description: Mul->ReLU->Reshape->Reshape->Add
Expectation: compile fails with RuntimeError.
"""
device_num = 8
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.reshape1 = P.Reshape()
self.reshape2 = P.Reshape()
self.relu = P.ReLU()
self.depend = P.Depend()
self.mul = P.Mul().shard(((2, 4), (2, 4)))
self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
self.add = P.Add().shard(((4, 2), (4, 2)))
def construct(self, x, y):
out1 = self.mul(x, self.mul_weight)
y = self.relu(y)
out2 = self.reshape1(y, (96, 32, 4))
out3 = self.depend(out2, out1)
out3 = self.reshape2(out3, (128, 96))
out = out1 + out3
return out
class NetWithLoss1(nn.Cell):
def __init__(self, network):
super(NetWithLoss1, self).__init__()
self.mean = P.ReduceMean(keep_dims=False)
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.mean(predict, ())
x = Tensor(np.ones([128, 96]), dtype=ms.float32)
y = Tensor(np.ones([256, 48]), dtype=ms.float32)
net = GradWrapTwoInput(NetWithLoss1(Net()))
with pytest.raises(RuntimeError):
compile_graph_two_input(net, device_num, x, y)