!27432 sharding propatation for pangu

Merge pull request !27432 from bichaoyang/master
This commit is contained in:
i-robot 2021-12-15 02:25:42 +00:00 committed by Gitee
commit bf0fce0ebe
7 changed files with 34 additions and 14 deletions

View File

@ -374,7 +374,7 @@ StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithMiniComm(const StrategyPt
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
std::sort(next_stras.begin(), next_stras.end(),
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
return a.second < b.second;
return a.second != b.second ? a.second < b.second : a.first->PartitionNum() > b.first->PartitionNum();
});
return next_stras[0].first;
}
@ -384,7 +384,7 @@ StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithMiniComm(const StrategyPt
}
std::sort(next_op_stras.begin(), next_op_stras.end(),
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
return a.second < b.second;
return a.second != b.second ? a.second < b.second : a.first->PartitionNum() > b.first->PartitionNum();
});
return next_op_stras[0].first;
}
@ -414,7 +414,7 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithMiniComm(const StrategyPt
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
std::sort(prev_stras.begin(), prev_stras.end(),
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
return a.second < b.second;
return a.second != b.second ? a.second < b.second : a.first->PartitionNum() > b.first->PartitionNum();
});
return prev_stras[0].first;
}
@ -424,7 +424,7 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithMiniComm(const StrategyPt
}
std::sort(prev_op_stras.begin(), prev_op_stras.end(),
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
return a.second < b.second;
return a.second != b.second ? a.second < b.second : a.first->PartitionNum() > b.first->PartitionNum();
});
return prev_op_stras[0].first;
}

View File

@ -34,6 +34,10 @@ using CostPtrKey = std::pair<StrategyPtr, StrategyPtr>;
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
using EdgePtr = std::shared_ptr<mindspore::parallel::Edge>;
struct OpsPtrCompare {
bool operator()(const OperatorInfoPtr &a, const OperatorInfoPtr &b) const { return a->name().compare(b->name()) < 0; }
};
class Edge {
// An 'Edge' connects two Operators in the CostGraph.
public:

View File

@ -88,7 +88,7 @@ bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t outp
return false;
}
void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &ops_stras) {
void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &ops_stras) {
if (ops_stras.empty()) {
MS_LOG(EXCEPTION) << "There is no operator that is configured sharding strategy.";
}
@ -129,10 +129,11 @@ void CheckVisitedEdgeConsistency(const EdgePtr &edge) {
}
}
void CheckConfiguredSuccEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
void CheckConfiguredSuccEdgeConsistency(const EdgePtr &edge,
const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) {
auto curr_op = edge->prev_operator();
auto next_op = edge->next_operator();
auto next_op_conf_stra = configured_ops[next_op];
auto next_op_conf_stra = configured_ops.at(next_op);
if (curr_op->IsReshape()) {
const auto &reshape_output_lyt =
next_op->GetInputLayoutFromSWCByStrategy(next_op_conf_stra, edge->next_op_input_index());
@ -150,10 +151,11 @@ void CheckConfiguredSuccEdgeConsistency(const EdgePtr edge, std::map<OperatorInf
}
}
void CheckConfiguredPrevEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
void CheckConfiguredPrevEdgeConsistency(const EdgePtr &edge,
const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) {
auto curr_op = edge->next_operator();
auto prev_op = edge->prev_operator();
auto prev_op_conf_stra = configured_ops[prev_op];
auto prev_op_conf_stra = configured_ops.at(prev_op);
if (curr_op->IsReshape()) {
const auto &reshape_input_lyt =
prev_op->GetOutputLayoutFromSWCByStrategy(prev_op_conf_stra, edge->prev_op_output_index());
@ -171,7 +173,8 @@ void CheckConfiguredPrevEdgeConsistency(const EdgePtr edge, std::map<OperatorInf
}
void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
std::map<OperatorInfoPtr, StrategyPtr> configured_ops, std::map<OperatorInfoPtr, bool> *visited) {
const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> configured_ops,
std::map<OperatorInfoPtr, bool> *visited) {
std::queue<std::pair<std::pair<OperatorInfoPtr, std::pair<StrategyPtr, int64_t>>, int64_t>> next_level;
(void)next_level.emplace(std::make_pair(op, std::make_pair(op_stra, -1)), 0);
while (!next_level.empty()) {
@ -188,6 +191,7 @@ void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
}
for (auto &edge : curr_op->succ_edges()) {
const auto &next_op = edge->next_operator();
MS_LOG(DEBUG) << "forward propagation at " << curr_op->name() << "->" << next_op->name();
if (visited->at(next_op)) {
CheckVisitedEdgeConsistency(edge);
continue;
@ -215,6 +219,7 @@ void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
}
for (auto &edge : curr_op->prev_edges()) {
const auto &prev_op = edge->prev_operator();
MS_LOG(DEBUG) << "backpropagation at " << curr_op->name() << "->" << prev_op->name();
if (visited->at(prev_op)) {
CheckVisitedEdgeConsistency(edge);
continue;

View File

@ -52,8 +52,8 @@ class CostGraph {
}
void RemoveOperator(const OperatorInfoPtr &op);
bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
void StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &);
void BFS(const OperatorInfoPtr &, const StrategyPtr &, std::map<OperatorInfoPtr, StrategyPtr>,
void StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &);
void BFS(const OperatorInfoPtr &, const StrategyPtr &, const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare>,
std::map<OperatorInfoPtr, bool> *);
// the edge is in the form: u --> v
void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);

View File

@ -267,6 +267,7 @@ constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLog
constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits";
constexpr char MATMUL[] = "MatMul";
constexpr char GELU[] = "GeLU";
constexpr char FAST_GELU[] = "FastGeLU";
constexpr char TANH[] = "Tanh";
constexpr char RECEIVE[] = "Receive";
constexpr char SEND[] = "Send";

View File

@ -155,7 +155,7 @@ bool IsElementWiseOperator(const std::string &op_name) {
bool IsSplittableOperator(const std::string &op_name) {
// clang-format off
static const std::set<std::string> splittable_op =
{MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
{MATMUL, TRANSPOSE, GELU, FAST_GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, AVGPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, STACK,
@ -232,7 +232,7 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn
}
// 'configured_stra_ops_' includes all operators that are configured sharding strategies.
std::map<OperatorInfoPtr, StrategyPtr> configured_stra_ops_;
std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> configured_stra_ops_;
void InitCostGraph() {
if (entire_costgraph == nullptr) {
entire_costgraph = std::make_shared<CostGraph>();

View File

@ -74,6 +74,16 @@ class Strategy {
return true;
}
int64_t PartitionNum() {
int64_t divergence = 1;
for (size_t i = 0; i < inputs_.size(); ++i) {
for (size_t j = 0; j < inputs_[i].size(); ++j) {
divergence *= inputs_[i][j];
}
}
return divergence;
}
// Include 'another_stra' into this strategy
void CoverStrategy(const StrategyPtr &another_stra) {
internal_stragies_.push_back(another_stra);