add a new graph operation in autoparallel

This commit is contained in:
Xiaoda Zhang 2020-07-28 09:37:51 +08:00
parent dc765dd248
commit d24a902afe
11 changed files with 471 additions and 55 deletions

View File

@ -79,6 +79,8 @@ class StrategyWithCost {
public:
StrategyWithCost(StrategyPtr strategy, std::vector<TensorInfo> inputs_, std::vector<TensorInfo> outputs_)
: strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {}
StrategyWithCost(StrategyPtr strategy, CostPtrList c_list)
: strategy_ptr(std::move(strategy)), cost_list(std::move(c_list)) {}
StrategyWithCost(const StrategyWithCost &swc) = delete;
StrategyWithCost(StrategyWithCost &&swc)
@ -99,6 +101,7 @@ enum DecisionType {
EDGE_ELIMINATION,
MERGE_ELIMINATION,
CONTRACT_ELIMINATION,
SOURCE_ELIMINATION,
TRIANGLE_ELIMINATION,
STAR_ELIMINATION,
FINAL_TYPE,
@ -199,6 +202,38 @@ struct ContractEliminationDecision : public Decision {
MS_DECLARE_PARENT(ContractEliminationDecision, Decision);
};
/* 'SourceEliminationDecision' is for the source Elimination in DP algorithm:
* 1 1,5
* / \ // \\
* / \ // \\
* / \ // \\
* / \ // \\
* 2 <- 5 -> 3 ==> 2 3
* \ / \ /
* \ / \ /
* \ / \ /
* 4 4
*
* In the original graph, '1' has two alive outgoing edges and no incoming edges. '5' has two alive outgoing edges and
* no incoming edges. '4' has two alive incoming edges and no outgoing edges. Source Elimination will merge '5' into
* '1' new edges are generated to replace the old ones incident to '1' and '5'.
*
*/
struct SourceEliminationDecision : public Decision {
SourceEliminationDecision(StrategyPtr op1_stra, CostPtr op1_c, StrategyPtr op2_stra, CostPtr op2_c)
: op1_strategy_(std::move(op1_stra)),
op1_cost_(std::move(op1_c)),
op2_strategy_(std::move(op2_stra)),
op2_cost_(std::move(op2_c)) {
type_ = DecisionType::SOURCE_ELIMINATION;
}
StrategyPtr op1_strategy_;
CostPtr op1_cost_;
StrategyPtr op2_strategy_;
CostPtr op2_cost_;
MS_DECLARE_PARENT(SourceEliminationDecision, Decision);
};
/* 'TriangleEliminationDecision' is for the Triangle Elimination in DP algorithm:
*
* u
@ -296,6 +331,7 @@ using OpEliminationDecisionPtr = std::shared_ptr<OpEliminationDecision>;
using EdgeEliminationDecisionPtr = std::shared_ptr<EdgeEliminationDecision>;
using MergeEliminationDecisionPtr = std::shared_ptr<MergeEliminationDecision>;
using ContractEliminationDecisionPtr = std::shared_ptr<ContractEliminationDecision>;
using SourceEliminationDecisionPtr = std::shared_ptr<SourceEliminationDecision>;
using TriangleEliminationDecisionPtr = std::shared_ptr<TriangleEliminationDecision>;
using StarEliminationDecisionPtr = std::shared_ptr<StarEliminationDecision>;
using FinalDecisionPtr = std::shared_ptr<FinalDecision>;

View File

@ -42,16 +42,19 @@ Status GetStrategy(const CostGraphPtr &graph) {
auto elimi = std::make_shared<OpElimination>(n_edge, l_edge, node, r_edge);
eliminations.emplace_back(std::move(elimi));
}
if (!flag) {
auto edges = graph->CheckEdgeElimination();
if ((!flag) && (!edges.empty())) {
if (!edges.empty()) {
// Applying the Edge Elimination
flag = true;
auto n_edge = graph->EliminationEdges(edges);
auto elimi = std::make_shared<EdgeElimination>(n_edge, edges);
eliminations.emplace_back(std::move(elimi));
}
}
if (!flag) {
auto merge_node = graph->CheckMergeElimination();
if ((!flag) && (merge_node != nullptr)) {
if (merge_node != nullptr) {
// Applying the Merge Elimination
flag = true;
auto succ_edge = merge_node->GetAliveSuccEdges()[0];
@ -59,8 +62,10 @@ Status GetStrategy(const CostGraphPtr &graph) {
auto elimi = std::make_shared<MergeElimination>(merge_node, succ_edge, target_node);
eliminations.emplace_back(std::move(elimi));
}
}
if (!flag) {
auto contracted_node = graph->CheckContractElimination();
if ((!flag) && (contracted_node != nullptr)) {
if ((contracted_node != nullptr)) {
// Applying the Contract Elimination
flag = true;
auto prev_edge = contracted_node->GetAlivePrevEdges()[0];
@ -68,8 +73,10 @@ Status GetStrategy(const CostGraphPtr &graph) {
auto elimi = std::make_shared<ContractElimination>(target_node, prev_edge, contracted_node);
eliminations.emplace_back(std::move(elimi));
}
}
if (!flag) {
auto triangle_pair = graph->CheckTriangleElimination();
if ((!flag) && (triangle_pair.first != nullptr)) {
if (triangle_pair.first != nullptr) {
// Applying the Triangle Elimination
flag = true;
auto eliminated_node = triangle_pair.first;
@ -90,8 +97,10 @@ Status GetStrategy(const CostGraphPtr &graph) {
std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
eliminations.emplace_back(std::move(elimi));
}
}
if (!flag) {
auto star_center = graph->CheckStarElimination();
if ((!flag) && (star_center != nullptr)) {
if (star_center != nullptr) {
// Applying the Star Elimination
flag = true;
auto succ_edges = graph->EliminationStar(star_center);
@ -104,6 +113,7 @@ Status GetStrategy(const CostGraphPtr &graph) {
eliminations.emplace_back(std::move(elimi));
}
}
}
// Phase 2: Search the cost_list in the final graph, and determine the optimal one
if (graph->SearchStrategy() != SUCCESS) {

View File

@ -42,7 +42,7 @@ namespace parallel {
// the operators' strategies can be all determined.
struct Elimination : public Base {
enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, TRIANGLE, STAR };
enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, SOURCE, TRIANGLE, STAR };
Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {}
EdgePtr new_edge_;
@ -100,6 +100,26 @@ struct ContractElimination : public Elimination {
MS_DECLARE_PARENT(ContractElimination, Elimination);
};
// Source Elimination
struct SourceElimination : public Elimination {
SourceElimination(OperatorInfoPtr p_source, std::vector<EdgePtr> p_succ_edges, std::vector<EdgePtr> p_new_succ_edges,
OperatorInfoPtr s_source, std::vector<EdgePtr> s_succ_edges, std::vector<EdgePtr> s_new_succ_edges)
: Elimination(nullptr, Elimination::EliminationType::SOURCE),
primary_source_(std::move(p_source)),
primary_succ_edges_(std::move(p_succ_edges)),
primary_new_succ_edges_(std::move(p_new_succ_edges)),
secondary_source_(std::move(s_source)),
secondary_succ_edges_(std::move(s_succ_edges)),
secondary_new_succ_edges_(std::move(s_new_succ_edges)) {}
OperatorInfoPtr primary_source_;
std::vector<EdgePtr> primary_succ_edges_;
std::vector<EdgePtr> primary_new_succ_edges_;
OperatorInfoPtr secondary_source_;
std::vector<EdgePtr> secondary_succ_edges_;
std::vector<EdgePtr> secondary_new_succ_edges_;
MS_DECLARE_PARENT(SourceElimination, Elimination);
};
// Triangle Elimination
struct TriangleElimination : public Elimination {
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
@ -138,6 +158,7 @@ using OpEliminationPtr = std::shared_ptr<OpElimination>;
using EdgeEliminationPtr = std::shared_ptr<EdgeElimination>;
using MergeEliminationPtr = std::shared_ptr<MergeElimination>;
using ContractEliminationPtr = std::shared_ptr<ContractElimination>;
using SourceEliminationPtr = std::shared_ptr<SourceElimination>;
using TriangleEliminationPtr = std::shared_ptr<TriangleElimination>;
using StarEliminationPtr = std::shared_ptr<StarElimination>;

View File

@ -320,5 +320,17 @@ Status Edge::CalculateMemoryCostForInference() {
}
return SUCCESS;
}
void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map) {
cost_map_ = cost_map;
pre_op_output_.clear();
next_op_input_.clear();
for (auto &key_value : cost_map_) {
auto &key_pair = key_value.first;
pre_op_output_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.first, {}));
next_op_input_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.second, {}));
}
}
} // namespace parallel
} // namespace mindspore

View File

@ -80,6 +80,8 @@ class Edge {
std::string edge_name() const { return edge_name_; }
// Init cost_map_: for each output layout and input layout, calculate the cost
Status InitEdgeCost();
std::map<CostPtrKey, CostPtrList> GetCostMap() { return cost_map_; }
void SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &);
// For two operators u--->v, given the output tensor layout of u,
// and the input tensor layout of v, return the redistribution cost,
// and the op_list to carry out the redistribution.

View File

@ -794,6 +794,191 @@ OperatorInfoPtr CostGraph::CheckContractElimination() const {
return nullptr;
}
std::pair<OperatorInfoPtr, OperatorInfoPtr> CostGraph::CheckSourceElimination() const {
size_t source_count = 0;
std::vector<OperatorInfoPtr> op_vector(2, nullptr);
for (auto &op : ops_) {
MS_EXCEPTION_IF_NULL(op);
bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() > 0;
if (bool_test) {
op_vector[source_count++] = op;
if (source_count == 2) {
return std::make_pair(op_vector[0], op_vector[1]);
}
}
}
return std::make_pair(nullptr, nullptr);
}
void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, const CostPtrList &op1_old_clist,
StrategyPtr op2_old_stra, const CostPtrList &op2_old_clist,
CostPtrList *op1_new_clist) {
for (auto &op1_cost : op1_old_clist) {
for (auto &op2_cost : op2_old_clist) {
double computation = op1_cost->computation_cost_ + op2_cost->computation_cost_;
double memory = op1_cost->memory_with_reuse_ + op2_cost->memory_with_reuse_;
double communication = op1_cost->communication_cost_ + op2_cost->communication_cost_;
double communication_forward = op1_cost->communication_forward_ + op2_cost->communication_forward_;
double communication_without_para =
op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_;
auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost);
auto new_cost = std::make_shared<Cost>(computation, communication, decision);
MS_EXCEPTION_IF_NULL(new_cost);
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->communication_forward_ = communication_forward;
MS_EXCEPTION_IF_NULL(op1_new_clist);
op1_new_clist->emplace_back(std::move(new_cost));
}
}
}
std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources(
OperatorInfoPtr op1, OperatorInfoPtr op2) {
MS_EXCEPTION_IF_NULL(op1);
MS_EXCEPTION_IF_NULL(op2);
MS_LOG(INFO) << "Now source eliminating node: " << op2->name() << " to node: " << op1->name();
auto op1_old_succ_edges = op1->GetAliveSuccEdges();
std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op1_edges_reorganised_cost(
op1_old_succ_edges.size());
std::vector<std::map<CostPtrKey, CostPtrList>> op1_new_edges_cost(op1_old_succ_edges.size());
std::vector<std::shared_ptr<Edge>> op1_new_succ_edges(op1_old_succ_edges.size());
auto op2_old_succ_edges = op2->GetAliveSuccEdges();
std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op2_edges_reorganised_cost(
op2_old_succ_edges.size());
std::vector<std::map<CostPtrKey, CostPtrList>> op2_new_edges_cost(op2_old_succ_edges.size());
std::vector<std::shared_ptr<Edge>> op2_new_succ_edges(op2_old_succ_edges.size());
// Construct cost_map for the data_structure of 'op1_edges_reorganised_cost' and 'op2_edges_reorganised_cost'
for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
const auto &op1_cost_map = op1_old_succ_edges[i]->GetCostMap();
std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
for (const auto &key_value : op1_cost_map) {
const auto &from_to_strategies = key_value.first;
const auto &costlist = key_value.second;
from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist));
}
op1_edges_reorganised_cost[i] = from_tocost;
}
for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
const auto &op2_cost_map = op2_old_succ_edges[i]->GetCostMap();
std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
for (const auto &key_value : op2_cost_map) {
const auto &from_to_strategies = key_value.first;
const auto &costlist = key_value.second;
from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist));
}
op2_edges_reorganised_cost[i] = from_tocost;
}
// Merge op2 into op1
const auto &op1_old_stra_cost = op1->GetStrategyCost();
const auto &op2_old_stra_cost = op2->GetStrategyCost();
std::vector<std::shared_ptr<StrategyWithCost>> op1_new_stra_cost;
for (auto &op1_stra_cost : op1_old_stra_cost) {
auto op1_old_stra = op1_stra_cost->strategy_ptr;
auto op1_old_costlist = op1_stra_cost->cost_list;
for (auto &op2_stra_cost : op2_old_stra_cost) {
auto op2_stra = op2_stra_cost->strategy_ptr;
auto op2_costlist = op2_stra_cost->cost_list;
StrategyPtr op1_new_stra = std::make_shared<Strategy>(*op1_old_stra);
op1_new_stra->CoverStrategy(op2_stra);
CostPtrList op1_new_costlist;
// Calculate new cost for 'op1_new_costlist'
CreateSourceEliminationSubCostList(op1_old_stra, op1_old_costlist, op2_stra, op2_costlist, &op1_new_costlist);
std::shared_ptr<StrategyWithCost> swc = std::make_shared<StrategyWithCost>(op1_new_stra, op1_new_costlist);
op1_new_stra_cost.emplace_back(swc);
// Set cost for new successive edges of op1 and op2
for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
auto &from_tocost = op1_edges_reorganised_cost[i];
auto &to_cost = from_tocost[op1_old_stra];
auto &new_cost_map = op1_new_edges_cost[i];
for (auto &stra_costlit : to_cost) {
auto &to_strategy = stra_costlit.first;
auto &edge_costlist = stra_costlit.second;
CostPtrKey new_key = {op1_new_stra, to_strategy};
new_cost_map[new_key] = edge_costlist;
}
}
for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
auto &from_tocost = op2_edges_reorganised_cost[i];
auto &to_cost = from_tocost[op2_stra];
auto &new_cost_map = op2_new_edges_cost[i];
for (auto &stra_costlist : to_cost) {
auto &to_strategy = stra_costlist.first;
auto &edge_costlist = stra_costlist.second;
CostPtrKey new_key = {op1_new_stra, to_strategy};
new_cost_map[new_key] = edge_costlist;
}
}
}
}
op1->SetStrategyCost(op1_new_stra_cost);
op2->SetNotAlive();
// Update the edges incident to op1, and edges incident to op2
for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
auto &new_cost_map = op1_new_edges_cost[i];
auto &ith_edge = op1_old_succ_edges[i];
std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
std::shared_ptr<Edge> new_edge;
if (ith_edge->is_combined()) {
std::vector<size_t> output_indexs, input_indexs;
output_indexs = ith_edge->prev_op_output_indexs();
input_indexs = ith_edge->next_op_input_indexs();
new_edge =
std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
} else {
size_t output_index, input_index;
output_index = ith_edge->prev_op_output_index();
input_index = ith_edge->next_op_input_index();
new_edge =
std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false);
}
new_edge->SetCostMapAndInputOutput(new_cost_map);
// replace the old successive edges with the new ones.
op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge);
ith_edge->next_operator()->ReplacePreEdge(op1, new_edge);
op1_new_succ_edges[i] = new_edge;
}
for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
auto &new_cost_map = op2_new_edges_cost[i];
auto &ith_edge = op2_old_succ_edges[i];
const auto &destination = ith_edge->next_operator();
std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
std::shared_ptr<Edge> new_edge;
if (ith_edge->is_combined()) {
std::vector<size_t> output_indexs, input_indexs;
output_indexs = ith_edge->prev_op_output_indexs();
input_indexs = ith_edge->next_op_input_indexs();
new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
} else {
size_t output_index, input_index;
output_index = ith_edge->prev_op_output_index();
input_index = ith_edge->next_op_input_index();
new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
}
new_edge->SetCostMapAndInputOutput(new_cost_map);
// replace the old successive edges with the new ones.
destination->ReplacePreEdge(op2, new_edge);
op1->AddSuccEdge(new_edge);
op2_new_succ_edges[i] = new_edge;
}
MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded.";
return {op1_new_succ_edges, op2_new_succ_edges};
}
// Check the graph whether a TriangleElimination can be performed
std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const {
for (auto &op : ops_) {

View File

@ -180,6 +180,14 @@ class CostGraph {
void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &,
const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>,
CostPtrList &, CostPtrList &, CostPtrList *);
// Return <op1, op2>. we merge 'op2' into 'op1'
std::pair<OperatorInfoPtr, OperatorInfoPtr> CheckSourceElimination() const;
void CreateSourceEliminationSubCostList(StrategyPtr, const CostPtrList &, StrategyPtr, const CostPtrList &,
CostPtrList *);
// We merge 'op2' into op1. The returned value are '<Edges1, Edges2>'. 'Edges1' are newly updated edges for 'op1',
// 'Edges2' are newly updated edges for 'op2'.
std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> EliminationSources(
OperatorInfoPtr op1, OperatorInfoPtr op2);
// Calculate memory cost for training phase or inference phase.
Status CalculateMemoryCost();
// When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then

View File

@ -1330,5 +1330,9 @@ void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) {
PrintStrategy(s_strategy);
}
}
void OperatorInfo::SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &stra_cost) {
strategy_cost_ = stra_cost;
}
} // namespace parallel
} // namespace mindspore

View File

@ -97,6 +97,7 @@ class OperatorInfo {
// is checked
Status SetCostUnderStrategyBase(const StrategyPtr &strategy);
std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
void SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &);
// In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving
// WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory
// at the end of forward phase.

View File

@ -36,7 +36,19 @@ using StrategyPtr = std::shared_ptr<Strategy>;
class Strategy {
public:
Strategy(int32_t stage, std::vector<Dimensions> inputs) : stage_(stage), inputs_(std::move(inputs)) {}
Strategy(int32_t stage, std::vector<Dimensions> inputs)
: stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {}
Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) {
inputs_ = another_stra.GetInputDim();
internal_size_ = another_stra.GetInternalSize();
if (internal_size_ != 0) {
internal_stragies_ = another_stra.GetInternalStrategies();
} else {
internal_stragies_ = {};
}
}
~Strategy() = default;
size_t GetInputNumber() const { return inputs_.size(); }
std::vector<Dimensions> GetInputDim() const { return inputs_; }
@ -47,7 +59,10 @@ class Strategy {
}
}
void ResetInputs(const std::vector<Dimensions> &input) { inputs_ = input; }
std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; }
size_t GetInternalSize() const { return internal_size_; }
// TODO(Xiaoda): need fix for adapting 'CoverStrategy'
bool IsEqual(const StrategyPtr &another_stra) {
if (another_stra == nullptr) {
return false;
@ -58,11 +73,19 @@ class Strategy {
return true;
}
// Include 'another_stra' into this strategy
void CoverStrategy(const StrategyPtr &another_stra) {
internal_stragies_.push_back(another_stra);
internal_size_++;
}
private:
const int32_t stage_;
// The size of Dimensions must equal to inputs_ tensor dimension.
std::vector<Dimensions> inputs_;
size_t internal_size_ = 0;
std::vector<StrategyPtr> internal_stragies_;
};
inline StrategyPtr NewStrategy(const int32_t stage, const std::vector<Dimensions> &inputs) {

View File

@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y, z, w, a):
predict = self.network(x, y, z, w, a)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y, z, w, a):
return C.grad_all(self.network)(x, y, z, w, a)
# model_parallel test
def test_double_source_graph():
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.matmul1 = P.MatMul()
self.matmul2 = P.MatMul()
self.matmul3 = P.MatMul()
self.matmul4 = P.MatMul()
self.matmul5 = P.MatMul()
def construct(self, x, y, z, w, a):
m1_result = self.matmul1(x, y)
m2_result = self.matmul2(z, w)
m3_result = self.matmul3(m2_result, m1_result)
m4_result = self.matmul4(m2_result, m1_result)
out = self.matmul5(m3_result, m4_result)
return out
size = 8
context.set_auto_parallel_context(device_num=size, global_rank=0)
x = Tensor(np.ones([32, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 32]), dtype=ms.float32)
z = Tensor(np.ones([32, 32]), dtype=ms.float32)
w = Tensor(np.ones([32, 32]), dtype=ms.float32)
a = Tensor(np.ones([32, 32]), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
_executor.compile(net, x, y, z, w, a)
def test_double_source_complex_graph():
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.matmul1 = P.MatMul()
self.matmul2 = P.MatMul()
self.matmul3 = P.MatMul()
self.matmul4 = P.MatMul()
self.matmul5 = P.MatMul()
self.matmul6 = P.MatMul()
def construct(self, x, y, z, w, a):
m1_result = self.matmul1(x, y)
m6_result = self.matmul6(m1_result, a)
m2_result = self.matmul2(z, w)
m3_result = self.matmul3(m2_result, m6_result)
m4_result = self.matmul4(m2_result, m1_result)
out = self.matmul5(m3_result, m4_result)
return out
size = 8
context.set_auto_parallel_context(device_num=size, global_rank=0)
x = Tensor(np.ones([32, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 32]), dtype=ms.float32)
z = Tensor(np.ones([32, 32]), dtype=ms.float32)
w = Tensor(np.ones([32, 32]), dtype=ms.float32)
a = Tensor(np.ones([32, 32]), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
_executor.compile(net, x, y, z, w, a)