forked from mindspore-Ecosystem/mindspore
!27432 sharding propagation for pangu
Merge pull request !27432 from bichaoyang/master
This commit is contained in:
commit
bf0fce0ebe
|
@ -374,7 +374,7 @@ StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithMiniComm(const StrategyPt
|
|||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
|
||||
std::sort(next_stras.begin(), next_stras.end(),
|
||||
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
|
||||
return a.second < b.second;
|
||||
return a.second != b.second ? a.second < b.second : a.first->PartitionNum() > b.first->PartitionNum();
|
||||
});
|
||||
return next_stras[0].first;
|
||||
}
|
||||
|
@ -384,7 +384,7 @@ StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithMiniComm(const StrategyPt
|
|||
}
|
||||
std::sort(next_op_stras.begin(), next_op_stras.end(),
|
||||
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
|
||||
return a.second < b.second;
|
||||
return a.second != b.second ? a.second < b.second : a.first->PartitionNum() > b.first->PartitionNum();
|
||||
});
|
||||
return next_op_stras[0].first;
|
||||
}
|
||||
|
@ -414,7 +414,7 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithMiniComm(const StrategyPt
|
|||
MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
|
||||
std::sort(prev_stras.begin(), prev_stras.end(),
|
||||
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
|
||||
return a.second < b.second;
|
||||
return a.second != b.second ? a.second < b.second : a.first->PartitionNum() > b.first->PartitionNum();
|
||||
});
|
||||
return prev_stras[0].first;
|
||||
}
|
||||
|
@ -424,7 +424,7 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithMiniComm(const StrategyPt
|
|||
}
|
||||
std::sort(prev_op_stras.begin(), prev_op_stras.end(),
|
||||
[](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
|
||||
return a.second < b.second;
|
||||
return a.second != b.second ? a.second < b.second : a.first->PartitionNum() > b.first->PartitionNum();
|
||||
});
|
||||
return prev_op_stras[0].first;
|
||||
}
|
||||
|
|
|
@ -34,6 +34,10 @@ using CostPtrKey = std::pair<StrategyPtr, StrategyPtr>;
|
|||
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
|
||||
using EdgePtr = std::shared_ptr<mindspore::parallel::Edge>;
|
||||
|
||||
struct OpsPtrCompare {
|
||||
bool operator()(const OperatorInfoPtr &a, const OperatorInfoPtr &b) const { return a->name().compare(b->name()) < 0; }
|
||||
};
|
||||
|
||||
class Edge {
|
||||
// An 'Edge' connects two Operators in the CostGraph.
|
||||
public:
|
||||
|
|
|
@ -88,7 +88,7 @@ bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t outp
|
|||
return false;
|
||||
}
|
||||
|
||||
void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &ops_stras) {
|
||||
void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &ops_stras) {
|
||||
if (ops_stras.empty()) {
|
||||
MS_LOG(EXCEPTION) << "There is no operator that is configured sharding strategy.";
|
||||
}
|
||||
|
@ -129,10 +129,11 @@ void CheckVisitedEdgeConsistency(const EdgePtr &edge) {
|
|||
}
|
||||
}
|
||||
|
||||
void CheckConfiguredSuccEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
|
||||
void CheckConfiguredSuccEdgeConsistency(const EdgePtr &edge,
|
||||
const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) {
|
||||
auto curr_op = edge->prev_operator();
|
||||
auto next_op = edge->next_operator();
|
||||
auto next_op_conf_stra = configured_ops[next_op];
|
||||
auto next_op_conf_stra = configured_ops.at(next_op);
|
||||
if (curr_op->IsReshape()) {
|
||||
const auto &reshape_output_lyt =
|
||||
next_op->GetInputLayoutFromSWCByStrategy(next_op_conf_stra, edge->next_op_input_index());
|
||||
|
@ -150,10 +151,11 @@ void CheckConfiguredSuccEdgeConsistency(const EdgePtr edge, std::map<OperatorInf
|
|||
}
|
||||
}
|
||||
|
||||
void CheckConfiguredPrevEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
|
||||
void CheckConfiguredPrevEdgeConsistency(const EdgePtr &edge,
|
||||
const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) {
|
||||
auto curr_op = edge->next_operator();
|
||||
auto prev_op = edge->prev_operator();
|
||||
auto prev_op_conf_stra = configured_ops[prev_op];
|
||||
auto prev_op_conf_stra = configured_ops.at(prev_op);
|
||||
if (curr_op->IsReshape()) {
|
||||
const auto &reshape_input_lyt =
|
||||
prev_op->GetOutputLayoutFromSWCByStrategy(prev_op_conf_stra, edge->prev_op_output_index());
|
||||
|
@ -171,7 +173,8 @@ void CheckConfiguredPrevEdgeConsistency(const EdgePtr edge, std::map<OperatorInf
|
|||
}
|
||||
|
||||
void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
|
||||
std::map<OperatorInfoPtr, StrategyPtr> configured_ops, std::map<OperatorInfoPtr, bool> *visited) {
|
||||
const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> configured_ops,
|
||||
std::map<OperatorInfoPtr, bool> *visited) {
|
||||
std::queue<std::pair<std::pair<OperatorInfoPtr, std::pair<StrategyPtr, int64_t>>, int64_t>> next_level;
|
||||
(void)next_level.emplace(std::make_pair(op, std::make_pair(op_stra, -1)), 0);
|
||||
while (!next_level.empty()) {
|
||||
|
@ -188,6 +191,7 @@ void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
|
|||
}
|
||||
for (auto &edge : curr_op->succ_edges()) {
|
||||
const auto &next_op = edge->next_operator();
|
||||
MS_LOG(DEBUG) << "forward propagation at " << curr_op->name() << "->" << next_op->name();
|
||||
if (visited->at(next_op)) {
|
||||
CheckVisitedEdgeConsistency(edge);
|
||||
continue;
|
||||
|
@ -215,6 +219,7 @@ void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
|
|||
}
|
||||
for (auto &edge : curr_op->prev_edges()) {
|
||||
const auto &prev_op = edge->prev_operator();
|
||||
MS_LOG(DEBUG) << "backpropagation at " << curr_op->name() << "->" << prev_op->name();
|
||||
if (visited->at(prev_op)) {
|
||||
CheckVisitedEdgeConsistency(edge);
|
||||
continue;
|
||||
|
|
|
@ -52,8 +52,8 @@ class CostGraph {
|
|||
}
|
||||
void RemoveOperator(const OperatorInfoPtr &op);
|
||||
bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
|
||||
void StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &);
|
||||
void BFS(const OperatorInfoPtr &, const StrategyPtr &, std::map<OperatorInfoPtr, StrategyPtr>,
|
||||
void StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &);
|
||||
void BFS(const OperatorInfoPtr &, const StrategyPtr &, const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare>,
|
||||
std::map<OperatorInfoPtr, bool> *);
|
||||
// the edge is in the form: u --> v
|
||||
void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);
|
||||
|
|
|
@ -267,6 +267,7 @@ constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLog
|
|||
constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits";
|
||||
constexpr char MATMUL[] = "MatMul";
|
||||
constexpr char GELU[] = "GeLU";
|
||||
constexpr char FAST_GELU[] = "FastGeLU";
|
||||
constexpr char TANH[] = "Tanh";
|
||||
constexpr char RECEIVE[] = "Receive";
|
||||
constexpr char SEND[] = "Send";
|
||||
|
|
|
@ -155,7 +155,7 @@ bool IsElementWiseOperator(const std::string &op_name) {
|
|||
bool IsSplittableOperator(const std::string &op_name) {
|
||||
// clang-format off
|
||||
static const std::set<std::string> splittable_op =
|
||||
{MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
|
||||
{MATMUL, TRANSPOSE, GELU, FAST_GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
|
||||
FLOORDIV, L2_NORMALIZE, ADD, MAXPOOL, AVGPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
|
||||
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
|
||||
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, STACK,
|
||||
|
@ -232,7 +232,7 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn
|
|||
}
|
||||
|
||||
// 'configured_stra_ops_' includes all operators that have been configured with sharding strategies.
|
||||
std::map<OperatorInfoPtr, StrategyPtr> configured_stra_ops_;
|
||||
std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> configured_stra_ops_;
|
||||
void InitCostGraph() {
|
||||
if (entire_costgraph == nullptr) {
|
||||
entire_costgraph = std::make_shared<CostGraph>();
|
||||
|
|
|
@ -74,6 +74,16 @@ class Strategy {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Returns the product of every dimension value stored in 'inputs_', taken
// over all of its entries. The result is 1 when 'inputs_' holds no values.
int64_t PartitionNum() {
  int64_t product = 1;
  for (const auto &input_dims : inputs_) {
    for (const auto &dim : input_dims) {
      product *= dim;
    }
  }
  return product;
}
|
||||
|
||||
// Include 'another_stra' into this strategy
|
||||
void CoverStrategy(const StrategyPtr &another_stra) {
|
||||
internal_stragies_.push_back(another_stra);
|
||||
|
|
Loading…
Reference in New Issue