forked from OSSInnovation/mindspore
add a new graph operation in autoparallel
This commit is contained in:
parent
dc765dd248
commit
d24a902afe
|
@ -79,6 +79,8 @@ class StrategyWithCost {
|
|||
public:
|
||||
StrategyWithCost(StrategyPtr strategy, std::vector<TensorInfo> inputs_, std::vector<TensorInfo> outputs_)
|
||||
: strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {}
|
||||
StrategyWithCost(StrategyPtr strategy, CostPtrList c_list)
|
||||
: strategy_ptr(std::move(strategy)), cost_list(std::move(c_list)) {}
|
||||
|
||||
StrategyWithCost(const StrategyWithCost &swc) = delete;
|
||||
StrategyWithCost(StrategyWithCost &&swc)
|
||||
|
@ -99,6 +101,7 @@ enum DecisionType {
|
|||
EDGE_ELIMINATION,
|
||||
MERGE_ELIMINATION,
|
||||
CONTRACT_ELIMINATION,
|
||||
SOURCE_ELIMINATION,
|
||||
TRIANGLE_ELIMINATION,
|
||||
STAR_ELIMINATION,
|
||||
FINAL_TYPE,
|
||||
|
@ -199,6 +202,38 @@ struct ContractEliminationDecision : public Decision {
|
|||
MS_DECLARE_PARENT(ContractEliminationDecision, Decision);
|
||||
};
|
||||
|
||||
/* 'SourceEliminationDecision' is for the source Elimination in DP algorithm:
|
||||
* 1 1,5
|
||||
* / \ // \\
|
||||
* / \ // \\
|
||||
* / \ // \\
|
||||
* / \ // \\
|
||||
* 2 <- 5 -> 3 ==> 2 3
|
||||
* \ / \ /
|
||||
* \ / \ /
|
||||
* \ / \ /
|
||||
* 4 4
|
||||
*
|
||||
* In the original graph, '1' has two alive outgoing edges and no incoming edges. '5' has two alive outgoing edges and
|
||||
* no incoming edges. '4' has two alive incoming edges and no outgoing edges. Source Elimination will merge '5' into
|
||||
* '1' new edges are generated to replace the old ones incident to '1' and '5'.
|
||||
*
|
||||
*/
|
||||
struct SourceEliminationDecision : public Decision {
|
||||
SourceEliminationDecision(StrategyPtr op1_stra, CostPtr op1_c, StrategyPtr op2_stra, CostPtr op2_c)
|
||||
: op1_strategy_(std::move(op1_stra)),
|
||||
op1_cost_(std::move(op1_c)),
|
||||
op2_strategy_(std::move(op2_stra)),
|
||||
op2_cost_(std::move(op2_c)) {
|
||||
type_ = DecisionType::SOURCE_ELIMINATION;
|
||||
}
|
||||
StrategyPtr op1_strategy_;
|
||||
CostPtr op1_cost_;
|
||||
StrategyPtr op2_strategy_;
|
||||
CostPtr op2_cost_;
|
||||
MS_DECLARE_PARENT(SourceEliminationDecision, Decision);
|
||||
};
|
||||
|
||||
/* 'TriangleEliminationDecision' is for the Triangle Elimination in DP algorithm:
|
||||
*
|
||||
* u
|
||||
|
@ -296,6 +331,7 @@ using OpEliminationDecisionPtr = std::shared_ptr<OpEliminationDecision>;
|
|||
using EdgeEliminationDecisionPtr = std::shared_ptr<EdgeEliminationDecision>;
|
||||
using MergeEliminationDecisionPtr = std::shared_ptr<MergeEliminationDecision>;
|
||||
using ContractEliminationDecisionPtr = std::shared_ptr<ContractEliminationDecision>;
|
||||
using SourceEliminationDecisionPtr = std::shared_ptr<SourceEliminationDecision>;
|
||||
using TriangleEliminationDecisionPtr = std::shared_ptr<TriangleEliminationDecision>;
|
||||
using StarEliminationDecisionPtr = std::shared_ptr<StarEliminationDecision>;
|
||||
using FinalDecisionPtr = std::shared_ptr<FinalDecision>;
|
||||
|
|
|
@ -42,66 +42,76 @@ Status GetStrategy(const CostGraphPtr &graph) {
|
|||
auto elimi = std::make_shared<OpElimination>(n_edge, l_edge, node, r_edge);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
}
|
||||
auto edges = graph->CheckEdgeElimination();
|
||||
if ((!flag) && (!edges.empty())) {
|
||||
// Applying the Edge Elimination
|
||||
flag = true;
|
||||
auto n_edge = graph->EliminationEdges(edges);
|
||||
auto elimi = std::make_shared<EdgeElimination>(n_edge, edges);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
if (!flag) {
|
||||
auto edges = graph->CheckEdgeElimination();
|
||||
if (!edges.empty()) {
|
||||
// Applying the Edge Elimination
|
||||
flag = true;
|
||||
auto n_edge = graph->EliminationEdges(edges);
|
||||
auto elimi = std::make_shared<EdgeElimination>(n_edge, edges);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
}
|
||||
}
|
||||
auto merge_node = graph->CheckMergeElimination();
|
||||
if ((!flag) && (merge_node != nullptr)) {
|
||||
// Applying the Merge Elimination
|
||||
flag = true;
|
||||
auto succ_edge = merge_node->GetAliveSuccEdges()[0];
|
||||
auto target_node = graph->EliminationMerge(merge_node);
|
||||
auto elimi = std::make_shared<MergeElimination>(merge_node, succ_edge, target_node);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
if (!flag) {
|
||||
auto merge_node = graph->CheckMergeElimination();
|
||||
if (merge_node != nullptr) {
|
||||
// Applying the Merge Elimination
|
||||
flag = true;
|
||||
auto succ_edge = merge_node->GetAliveSuccEdges()[0];
|
||||
auto target_node = graph->EliminationMerge(merge_node);
|
||||
auto elimi = std::make_shared<MergeElimination>(merge_node, succ_edge, target_node);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
}
|
||||
}
|
||||
auto contracted_node = graph->CheckContractElimination();
|
||||
if ((!flag) && (contracted_node != nullptr)) {
|
||||
// Applying the Contract Elimination
|
||||
flag = true;
|
||||
auto prev_edge = contracted_node->GetAlivePrevEdges()[0];
|
||||
auto target_node = graph->EliminationContract(contracted_node);
|
||||
auto elimi = std::make_shared<ContractElimination>(target_node, prev_edge, contracted_node);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
if (!flag) {
|
||||
auto contracted_node = graph->CheckContractElimination();
|
||||
if ((contracted_node != nullptr)) {
|
||||
// Applying the Contract Elimination
|
||||
flag = true;
|
||||
auto prev_edge = contracted_node->GetAlivePrevEdges()[0];
|
||||
auto target_node = graph->EliminationContract(contracted_node);
|
||||
auto elimi = std::make_shared<ContractElimination>(target_node, prev_edge, contracted_node);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
}
|
||||
}
|
||||
auto triangle_pair = graph->CheckTriangleElimination();
|
||||
if ((!flag) && (triangle_pair.first != nullptr)) {
|
||||
// Applying the Triangle Elimination
|
||||
flag = true;
|
||||
auto eliminated_node = triangle_pair.first;
|
||||
auto l_r_edge = triangle_pair.second;
|
||||
if (!flag) {
|
||||
auto triangle_pair = graph->CheckTriangleElimination();
|
||||
if (triangle_pair.first != nullptr) {
|
||||
// Applying the Triangle Elimination
|
||||
flag = true;
|
||||
auto eliminated_node = triangle_pair.first;
|
||||
auto l_r_edge = triangle_pair.second;
|
||||
|
||||
auto left_node = l_r_edge->prev_operator();
|
||||
auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
|
||||
auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
|
||||
MS_EXCEPTION_IF_NULL(left_edge);
|
||||
if (left_edge->next_operator() != left_node) {
|
||||
auto tmp = left_edge;
|
||||
left_edge = right_edge;
|
||||
right_edge = tmp;
|
||||
auto left_node = l_r_edge->prev_operator();
|
||||
auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
|
||||
auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
|
||||
MS_EXCEPTION_IF_NULL(left_edge);
|
||||
if (left_edge->next_operator() != left_node) {
|
||||
auto tmp = left_edge;
|
||||
left_edge = right_edge;
|
||||
right_edge = tmp;
|
||||
}
|
||||
auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
|
||||
auto right_node = l_r_edge->next_operator();
|
||||
auto elimi =
|
||||
std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
}
|
||||
auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
|
||||
auto right_node = l_r_edge->next_operator();
|
||||
auto elimi =
|
||||
std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
}
|
||||
auto star_center = graph->CheckStarElimination();
|
||||
if ((!flag) && (star_center != nullptr)) {
|
||||
// Applying the Star Elimination
|
||||
flag = true;
|
||||
auto succ_edges = graph->EliminationStar(star_center);
|
||||
std::vector<OperatorInfoPtr> succ_nodes;
|
||||
for (size_t i = 0; i < succ_edges.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(succ_edges[i]);
|
||||
succ_nodes.push_back(succ_edges[i]->next_operator());
|
||||
if (!flag) {
|
||||
auto star_center = graph->CheckStarElimination();
|
||||
if (star_center != nullptr) {
|
||||
// Applying the Star Elimination
|
||||
flag = true;
|
||||
auto succ_edges = graph->EliminationStar(star_center);
|
||||
std::vector<OperatorInfoPtr> succ_nodes;
|
||||
for (size_t i = 0; i < succ_edges.size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL(succ_edges[i]);
|
||||
succ_nodes.push_back(succ_edges[i]->next_operator());
|
||||
}
|
||||
auto elimi = std::make_shared<StarElimination>(star_center, succ_edges, succ_nodes);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
}
|
||||
auto elimi = std::make_shared<StarElimination>(star_center, succ_edges, succ_nodes);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ namespace parallel {
|
|||
// the operators' strategies can be all determined.
|
||||
|
||||
struct Elimination : public Base {
|
||||
enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, TRIANGLE, STAR };
|
||||
enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, SOURCE, TRIANGLE, STAR };
|
||||
Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {}
|
||||
|
||||
EdgePtr new_edge_;
|
||||
|
@ -100,6 +100,26 @@ struct ContractElimination : public Elimination {
|
|||
MS_DECLARE_PARENT(ContractElimination, Elimination);
|
||||
};
|
||||
|
||||
// Source Elimination
|
||||
struct SourceElimination : public Elimination {
|
||||
SourceElimination(OperatorInfoPtr p_source, std::vector<EdgePtr> p_succ_edges, std::vector<EdgePtr> p_new_succ_edges,
|
||||
OperatorInfoPtr s_source, std::vector<EdgePtr> s_succ_edges, std::vector<EdgePtr> s_new_succ_edges)
|
||||
: Elimination(nullptr, Elimination::EliminationType::SOURCE),
|
||||
primary_source_(std::move(p_source)),
|
||||
primary_succ_edges_(std::move(p_succ_edges)),
|
||||
primary_new_succ_edges_(std::move(p_new_succ_edges)),
|
||||
secondary_source_(std::move(s_source)),
|
||||
secondary_succ_edges_(std::move(s_succ_edges)),
|
||||
secondary_new_succ_edges_(std::move(s_new_succ_edges)) {}
|
||||
OperatorInfoPtr primary_source_;
|
||||
std::vector<EdgePtr> primary_succ_edges_;
|
||||
std::vector<EdgePtr> primary_new_succ_edges_;
|
||||
OperatorInfoPtr secondary_source_;
|
||||
std::vector<EdgePtr> secondary_succ_edges_;
|
||||
std::vector<EdgePtr> secondary_new_succ_edges_;
|
||||
MS_DECLARE_PARENT(SourceElimination, Elimination);
|
||||
};
|
||||
|
||||
// Triangle Elimination
|
||||
struct TriangleElimination : public Elimination {
|
||||
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
|
||||
|
@ -138,6 +158,7 @@ using OpEliminationPtr = std::shared_ptr<OpElimination>;
|
|||
using EdgeEliminationPtr = std::shared_ptr<EdgeElimination>;
|
||||
using MergeEliminationPtr = std::shared_ptr<MergeElimination>;
|
||||
using ContractEliminationPtr = std::shared_ptr<ContractElimination>;
|
||||
using SourceEliminationPtr = std::shared_ptr<SourceElimination>;
|
||||
using TriangleEliminationPtr = std::shared_ptr<TriangleElimination>;
|
||||
using StarEliminationPtr = std::shared_ptr<StarElimination>;
|
||||
|
||||
|
|
|
@ -320,5 +320,17 @@ Status Edge::CalculateMemoryCostForInference() {
|
|||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map) {
|
||||
cost_map_ = cost_map;
|
||||
pre_op_output_.clear();
|
||||
next_op_input_.clear();
|
||||
|
||||
for (auto &key_value : cost_map_) {
|
||||
auto &key_pair = key_value.first;
|
||||
pre_op_output_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.first, {}));
|
||||
next_op_input_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.second, {}));
|
||||
}
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -80,6 +80,8 @@ class Edge {
|
|||
std::string edge_name() const { return edge_name_; }
|
||||
// Init cost_map_: for each output layout and input layout, calculate the cost
|
||||
Status InitEdgeCost();
|
||||
std::map<CostPtrKey, CostPtrList> GetCostMap() { return cost_map_; }
|
||||
void SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &);
|
||||
// For two operators u--->v, given the output tensor layout of u,
|
||||
// and the input tensor layout of v, return the redistribution cost,
|
||||
// and the op_list to carry out the redistribution.
|
||||
|
|
|
@ -794,6 +794,191 @@ OperatorInfoPtr CostGraph::CheckContractElimination() const {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::pair<OperatorInfoPtr, OperatorInfoPtr> CostGraph::CheckSourceElimination() const {
|
||||
size_t source_count = 0;
|
||||
std::vector<OperatorInfoPtr> op_vector(2, nullptr);
|
||||
for (auto &op : ops_) {
|
||||
MS_EXCEPTION_IF_NULL(op);
|
||||
bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() > 0;
|
||||
if (bool_test) {
|
||||
op_vector[source_count++] = op;
|
||||
if (source_count == 2) {
|
||||
return std::make_pair(op_vector[0], op_vector[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_pair(nullptr, nullptr);
|
||||
}
|
||||
|
||||
// Build the cost list of a merged source operator for one <op1 strategy, op2 strategy> pair.
//
// For every combination of a cost in 'op1_old_clist' and a cost in 'op2_old_clist', a new Cost is appended to
// '*op1_new_clist' whose computation, communication, forward communication, parameter-free communication and
// reusable memory are the sums of the two inputs. Each new cost carries a SourceEliminationDecision recording the
// two strategies and the two costs it was built from.
//
// 'op1_new_clist' must not be null; this is now checked once up front (before any work is done) instead of once
// per generated cost. Null entries in either input list raise an exception instead of being dereferenced.
void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, const CostPtrList &op1_old_clist,
                                                   StrategyPtr op2_old_stra, const CostPtrList &op2_old_clist,
                                                   CostPtrList *op1_new_clist) {
  MS_EXCEPTION_IF_NULL(op1_new_clist);
  for (auto &op1_cost : op1_old_clist) {
    MS_EXCEPTION_IF_NULL(op1_cost);
    for (auto &op2_cost : op2_old_clist) {
      MS_EXCEPTION_IF_NULL(op2_cost);
      double computation = op1_cost->computation_cost_ + op2_cost->computation_cost_;
      double memory = op1_cost->memory_with_reuse_ + op2_cost->memory_with_reuse_;
      double communication = op1_cost->communication_cost_ + op2_cost->communication_cost_;
      double communication_forward = op1_cost->communication_forward_ + op2_cost->communication_forward_;
      double communication_without_para =
        op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_;

      auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost);
      auto new_cost = std::make_shared<Cost>(computation, communication, decision);
      MS_EXCEPTION_IF_NULL(new_cost);
      new_cost->communication_without_parameter_ = communication_without_para;
      // Parameter-related communication is weighted by COST_MODEL_GAMMA.
      new_cost->communication_with_partial_para_ =
        communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
      new_cost->memory_with_reuse_ = memory;
      new_cost->communication_forward_ = communication_forward;
      op1_new_clist->emplace_back(std::move(new_cost));
    }
  }
}
|
||||
|
||||
// Source Elimination: merge the source operator 'op2' into the source operator 'op1'.
//
// Steps:
//   1. Regroup the cost map of every alive successive edge of 'op1' and 'op2' by source strategy.
//   2. For every <op1 strategy, op2 strategy> pair, create a merged strategy (a copy of the op1 strategy that
//      covers the op2 strategy) with a combined cost list, and re-key the edge costs to the merged strategy.
//   3. Install the merged strategies on 'op1' and mark 'op2' as not alive.
//   4. Replace every old successive edge of 'op1' and 'op2' by a new edge that starts at 'op1'.
//
// Returns <Edges1, Edges2>: the new edges replacing the old successive edges of 'op1' and of 'op2' respectively,
// in the same order as the old edges. Null operators, edges, edge destinations or strategy entries raise an
// exception.
std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources(
  OperatorInfoPtr op1, OperatorInfoPtr op2) {
  MS_EXCEPTION_IF_NULL(op1);
  MS_EXCEPTION_IF_NULL(op2);
  MS_LOG(INFO) << "Now source eliminating node: " << op2->name() << " to node: " << op1->name();

  // For one edge: source strategy -> list of <destination strategy, cost list>.
  using CostsBySource = std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>;
  using EdgeCostMap = std::map<CostPtrKey, CostPtrList>;

  auto op1_old_succ_edges = op1->GetAliveSuccEdges();
  auto op2_old_succ_edges = op2->GetAliveSuccEdges();
  using EdgeList = decltype(op1_old_succ_edges);

  // Regroup the cost map of each edge in 'edges' by its source strategy.
  auto group_by_source = [](const EdgeList &edges) -> std::vector<CostsBySource> {
    std::vector<CostsBySource> grouped(edges.size());
    for (size_t i = 0; i < edges.size(); ++i) {
      MS_EXCEPTION_IF_NULL(edges[i]);
      const auto cost_map = edges[i]->GetCostMap();
      for (const auto &key_value : cost_map) {
        const auto &from_to_strategies = key_value.first;
        grouped[i][from_to_strategies.first].emplace_back(
          std::make_pair(from_to_strategies.second, key_value.second));
      }
    }
    return grouped;
  };

  // For each edge, copy the costs recorded under 'old_stra' into '(*new_costs)[i]', re-keyed by 'new_stra'.
  // Edges with no cost recorded for 'old_stra' contribute nothing.
  auto copy_edge_costs = [](const std::vector<CostsBySource> &grouped, const StrategyPtr &old_stra,
                            const StrategyPtr &new_stra, std::vector<EdgeCostMap> *new_costs) {
    for (size_t i = 0; i < grouped.size(); ++i) {
      const auto found = grouped[i].find(old_stra);
      if (found == grouped[i].end()) {
        continue;
      }
      for (const auto &stra_costlist : found->second) {
        CostPtrKey new_key = {new_stra, stra_costlist.first};
        (*new_costs)[i][new_key] = stra_costlist.second;
      }
    }
  };

  // Create the edge 'op1' -> 'destination' that mirrors the indices of 'old_edge' and carries '*cost_map'.
  auto rebuild_edge = [&op1](const std::shared_ptr<Edge> &old_edge, const OperatorInfoPtr &destination,
                             EdgeCostMap *cost_map) -> std::shared_ptr<Edge> {
    std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
    std::shared_ptr<Edge> new_edge;
    if (old_edge->is_combined()) {
      std::vector<size_t> output_indexs = old_edge->prev_op_output_indexs();
      std::vector<size_t> input_indexs = old_edge->next_op_input_indexs();
      new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
    } else {
      size_t output_index = old_edge->prev_op_output_index();
      size_t input_index = old_edge->next_op_input_index();
      new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
    }
    new_edge->SetCostMapAndInputOutput(*cost_map);
    return new_edge;
  };

  // Step 1: regroup the old edge costs.
  const auto op1_edges_reorganised_cost = group_by_source(op1_old_succ_edges);
  const auto op2_edges_reorganised_cost = group_by_source(op2_old_succ_edges);
  std::vector<EdgeCostMap> op1_new_edges_cost(op1_old_succ_edges.size());
  std::vector<EdgeCostMap> op2_new_edges_cost(op2_old_succ_edges.size());

  // Step 2: merge op2 into op1, one strategy pair at a time.
  const auto op1_old_stra_cost = op1->GetStrategyCost();
  const auto op2_old_stra_cost = op2->GetStrategyCost();
  std::vector<std::shared_ptr<StrategyWithCost>> op1_new_stra_cost;

  for (const auto &op1_stra_cost : op1_old_stra_cost) {
    MS_EXCEPTION_IF_NULL(op1_stra_cost);
    const auto &op1_old_stra = op1_stra_cost->strategy_ptr;
    const auto &op1_old_costlist = op1_stra_cost->cost_list;
    MS_EXCEPTION_IF_NULL(op1_old_stra);

    for (const auto &op2_stra_cost : op2_old_stra_cost) {
      MS_EXCEPTION_IF_NULL(op2_stra_cost);
      const auto &op2_stra = op2_stra_cost->strategy_ptr;
      const auto &op2_costlist = op2_stra_cost->cost_list;

      StrategyPtr op1_new_stra = std::make_shared<Strategy>(*op1_old_stra);
      op1_new_stra->CoverStrategy(op2_stra);
      CostPtrList op1_new_costlist;
      // Calculate new cost for 'op1_new_costlist'
      CreateSourceEliminationSubCostList(op1_old_stra, op1_old_costlist, op2_stra, op2_costlist, &op1_new_costlist);
      op1_new_stra_cost.emplace_back(std::make_shared<StrategyWithCost>(op1_new_stra, std::move(op1_new_costlist)));

      // Set cost for new successive edges of op1 and op2
      copy_edge_costs(op1_edges_reorganised_cost, op1_old_stra, op1_new_stra, &op1_new_edges_cost);
      copy_edge_costs(op2_edges_reorganised_cost, op2_stra, op1_new_stra, &op2_new_edges_cost);
    }
  }

  // Step 3: install the merged strategies and retire op2.
  op1->SetStrategyCost(op1_new_stra_cost);
  op2->SetNotAlive();

  // Step 4: update the edges incident to op1, and edges incident to op2
  std::vector<std::shared_ptr<Edge>> op1_new_succ_edges(op1_old_succ_edges.size());
  for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
    const auto &ith_edge = op1_old_succ_edges[i];
    const auto destination = ith_edge->next_operator();
    MS_EXCEPTION_IF_NULL(destination);
    auto new_edge = rebuild_edge(ith_edge, destination, &op1_new_edges_cost[i]);
    // replace the old successive edges with the new ones.
    op1->ReplaceSuccEdge(destination, new_edge);
    destination->ReplacePreEdge(op1, new_edge);
    op1_new_succ_edges[i] = new_edge;
  }

  std::vector<std::shared_ptr<Edge>> op2_new_succ_edges(op2_old_succ_edges.size());
  for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
    const auto &ith_edge = op2_old_succ_edges[i];
    const auto destination = ith_edge->next_operator();
    MS_EXCEPTION_IF_NULL(destination);
    auto new_edge = rebuild_edge(ith_edge, destination, &op2_new_edges_cost[i]);
    // The destination used to be fed by op2; it is now fed by op1 through the new edge.
    destination->ReplacePreEdge(op2, new_edge);
    op1->AddSuccEdge(new_edge);
    op2_new_succ_edges[i] = new_edge;
  }

  MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() << " succeeded.";
  return {op1_new_succ_edges, op2_new_succ_edges};
}
|
||||
|
||||
// Check the graph whether a TriangleElimination can be performed
|
||||
std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const {
|
||||
for (auto &op : ops_) {
|
||||
|
|
|
@ -180,6 +180,14 @@ class CostGraph {
|
|||
void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &,
|
||||
const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>,
|
||||
CostPtrList &, CostPtrList &, CostPtrList *);
|
||||
// Return <op1, op2>. We merge 'op2' into 'op1'.
|
||||
std::pair<OperatorInfoPtr, OperatorInfoPtr> CheckSourceElimination() const;
|
||||
void CreateSourceEliminationSubCostList(StrategyPtr, const CostPtrList &, StrategyPtr, const CostPtrList &,
|
||||
CostPtrList *);
|
||||
// We merge 'op2' into 'op1'. The returned value is '<Edges1, Edges2>'. 'Edges1' are the newly updated edges for 'op1',
|
||||
// 'Edges2' are newly updated edges for 'op2'.
|
||||
std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> EliminationSources(
|
||||
OperatorInfoPtr op1, OperatorInfoPtr op2);
|
||||
// Calculate memory cost for training phase or inference phase.
|
||||
Status CalculateMemoryCost();
|
||||
// When the input of an operator is neither a WEIGHT, nor an output of a subsequent operator involving WEIGHT, then
|
||||
|
|
|
@ -1330,5 +1330,9 @@ void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) {
|
|||
PrintStrategy(s_strategy);
|
||||
}
|
||||
}
|
||||
|
||||
void OperatorInfo::SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &stra_cost) {
|
||||
strategy_cost_ = stra_cost;
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -97,6 +97,7 @@ class OperatorInfo {
|
|||
// is checked
|
||||
Status SetCostUnderStrategyBase(const StrategyPtr &strategy);
|
||||
std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
|
||||
void SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &);
|
||||
// In the training phase, when the input of an operator contains WEIGHT or an output from other operators involving
|
||||
// WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory
|
||||
// at the end of forward phase.
|
||||
|
|
|
@ -36,7 +36,19 @@ using StrategyPtr = std::shared_ptr<Strategy>;
|
|||
|
||||
class Strategy {
|
||||
public:
|
||||
Strategy(int32_t stage, std::vector<Dimensions> inputs) : stage_(stage), inputs_(std::move(inputs)) {}
|
||||
Strategy(int32_t stage, std::vector<Dimensions> inputs)
|
||||
: stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {}
|
||||
|
||||
Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) {
|
||||
inputs_ = another_stra.GetInputDim();
|
||||
internal_size_ = another_stra.GetInternalSize();
|
||||
if (internal_size_ != 0) {
|
||||
internal_stragies_ = another_stra.GetInternalStrategies();
|
||||
} else {
|
||||
internal_stragies_ = {};
|
||||
}
|
||||
}
|
||||
|
||||
~Strategy() = default;
|
||||
size_t GetInputNumber() const { return inputs_.size(); }
|
||||
std::vector<Dimensions> GetInputDim() const { return inputs_; }
|
||||
|
@ -47,7 +59,10 @@ class Strategy {
|
|||
}
|
||||
}
|
||||
void ResetInputs(const std::vector<Dimensions> &input) { inputs_ = input; }
|
||||
std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; }
|
||||
size_t GetInternalSize() const { return internal_size_; }
|
||||
|
||||
// TODO(Xiaoda): need fix for adapting 'CoverStrategy'
|
||||
bool IsEqual(const StrategyPtr &another_stra) {
|
||||
if (another_stra == nullptr) {
|
||||
return false;
|
||||
|
@ -58,11 +73,19 @@ class Strategy {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Include 'another_stra' into this strategy
|
||||
void CoverStrategy(const StrategyPtr &another_stra) {
|
||||
internal_stragies_.push_back(another_stra);
|
||||
internal_size_++;
|
||||
}
|
||||
|
||||
private:
|
||||
const int32_t stage_;
|
||||
|
||||
// The size of Dimensions must equal to inputs_ tensor dimension.
|
||||
std::vector<Dimensions> inputs_;
|
||||
size_t internal_size_ = 0;
|
||||
std::vector<StrategyPtr> internal_stragies_;
|
||||
};
|
||||
|
||||
inline StrategyPtr NewStrategy(const int32_t stage, const std::vector<Dimensions> &inputs) {
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
||||
|
||||
class NetWithLoss(nn.Cell):
    """Cell that runs the wrapped five-input ``network`` and feeds its result to a ``VirtualLoss``."""

    def __init__(self, network):
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, z, w, a):
        # Forward pass of the wrapped network, followed by the loss cell.
        return self.loss(self.network(x, y, z, w, a))
|
||||
|
||||
|
||||
class GradWrap(nn.Cell):
    """Cell that applies ``C.grad_all`` to the wrapped ``network`` and calls the result on the five inputs."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y, z, w, a):
        grad_fn = C.grad_all(self.network)
        return grad_fn(x, y, z, w, a)
|
||||
|
||||
# model_parallel test
|
||||
|
||||
|
||||
def test_double_source_graph():
    """Compile, under auto_parallel, a graph whose two source MatMuls both feed the same two consumers.

    matmul1 and matmul2 read only network inputs; matmul3 and matmul4 each consume both of their results, and
    matmul5 joins matmul3 and matmul4. The input ``a`` is accepted but not used by this network.
    """

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.matmul3 = P.MatMul()
            self.matmul4 = P.MatMul()
            self.matmul5 = P.MatMul()

        def construct(self, x, y, z, w, a):
            source_a = self.matmul1(x, y)
            source_b = self.matmul2(z, w)
            left = self.matmul3(source_b, source_a)
            right = self.matmul4(source_b, source_a)
            return self.matmul5(left, right)

    device_count = 8
    context.set_auto_parallel_context(device_num=device_count, global_rank=0)
    # Five identical 32x32 float32 tensors of ones: x, y, z, w, a.
    inputs = [Tensor(np.ones([32, 32]), dtype=ms.float32) for _ in range(5)]

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, *inputs)
|
||||
|
||||
|
||||
def test_double_source_complex_graph():
    """Compile, under auto_parallel, a two-source graph with an extra MatMul behind one source.

    matmul1 and matmul2 read only network inputs. matmul6 multiplies matmul1's result by ``a``; matmul3 consumes
    matmul2 and matmul6, matmul4 consumes matmul2 and matmul1, and matmul5 joins matmul3 and matmul4.
    """

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.matmul3 = P.MatMul()
            self.matmul4 = P.MatMul()
            self.matmul5 = P.MatMul()
            self.matmul6 = P.MatMul()

        def construct(self, x, y, z, w, a):
            source_a = self.matmul1(x, y)
            extended_a = self.matmul6(source_a, a)
            source_b = self.matmul2(z, w)
            left = self.matmul3(source_b, extended_a)
            right = self.matmul4(source_b, source_a)
            return self.matmul5(left, right)

    device_count = 8
    context.set_auto_parallel_context(device_num=device_count, global_rank=0)
    # Five identical 32x32 float32 tensors of ones: x, y, z, w, a.
    inputs = [Tensor(np.ones([32, 32]), dtype=ms.float32) for _ in range(5)]

    net = GradWrap(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, *inputs)
|
Loading…
Reference in New Issue