forked from mindspore-Ecosystem/mindspore
support reshape in sharding propagation:
1) use the 'swc index' of 'strategy_cost_' as Reshape's selected strategy;
2) when encountering a Reshape in BFS, select the 'swc index' with zero communication cost;
3) when encountering a Reshape that has already been visited, check whether communication happens between the Reshape and the current operator; communication is acceptable if it occurs between two configured operators;
4) two consecutive Reshapes are currently not supported;
5) adjust the BFS structure in graph_costmodel.cc;
6) adjust some code in step_auto_parallel.cc to reduce cyclomatic complexity.
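For orientation, here is a minimal, hedged sketch in plain Python (not the MindSpore API) of the selection rule points 1)-3) describe for a Reshape's strategy-with-cost (SWC) entries; the SWC record and its field names are illustrative assumptions, not the real 'strategy_cost_' structure:

# Illustrative sketch only: mirrors the commit message under assumed data
# structures, not the actual ReshapeInfo/Edge implementation.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SWC:  # hypothetical stand-in for one strategy_cost_ entry
    layout: str        # tensor layout matched against the neighboring operator
    comm_cost: float   # communication cost of this candidate
    comp_cost: float   # computation cost of this candidate


def select_reshape_swc_index(candidates: List[SWC], neighbor_layout: str,
                             both_neighbors_configured: bool) -> Optional[int]:
    """Pick the SWC index used as the Reshape's selected strategy."""
    # 1) Prefer a candidate that matches the neighbor's layout with zero communication.
    zero_comm = [i for i, c in enumerate(candidates)
                 if c.layout == neighbor_layout and c.comm_cost == 0.0]
    if zero_comm:
        return min(zero_comm, key=lambda i: candidates[i].comp_cost)
    # 2) If both neighbors of the Reshape carry user-configured strategies,
    #    communication on this edge is acceptable, so drop the zero-comm filter.
    if both_neighbors_configured:
        matching = [i for i, c in enumerate(candidates) if c.layout == neighbor_layout]
        if matching:
            return min(matching, key=lambda i: candidates[i].comp_cost)
    return None  # the C++ code raises an exception in this case

This roughly corresponds to the GetSWCIndexBy*LayoutWithZeroComm lookups plus the GetSWCIndexBy*Layout fallback added in the diff below.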
parent 5ada60b949
commit a772767265

@@ -23,6 +23,7 @@
#include "frontend/parallel/auto_parallel/costmodel.h"
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "frontend/parallel/ops_info/reshape_info.h"

namespace mindspore {
namespace parallel {

@@ -394,6 +395,122 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPt
  return prev_op_stras[0].first;
}

int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra, int64_t curr_depth,
                                                 const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
  if (prev_op_->name().find(RESHAPEINFO) == std::string::npos) {
    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s prev_op is not a Reshape.";
  }
  if (next_op_->name().find(RESHAPEINFO) != std::string::npos) {
    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
  }
  const auto &reshape_output_layout = next_op_->GetInputLayoutFromSWCByStrategy(next_op_stra, next_op_input_index_);
  MS_LOG(INFO) << prev_op_->name() << "'s output layout: " << reshape_output_layout.ToString();
  auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
  // First, try to find the zero-communication strategy.
  auto swc_index = reshape_ptr->GetSWCIndexByOutputLayoutWithZeroComm(reshape_output_layout);
  if (swc_index == -1 && curr_depth == 0) {
    const auto &prev_edges = reshape_ptr->prev_edges();
    if (!prev_edges.empty()) {
      auto prev_edge = prev_edges[0];
      if (configured_ops.find(prev_edge->prev_operator()) != configured_ops.end()) {
        // Here, it is sure that Reshape's previous and next operators both have configured strategies,
        // thus, it is OK that communication happens here.
        swc_index = reshape_ptr->GetSWCIndexByOutputLayout(reshape_output_layout);
      }
    }
  }
  if (swc_index == -1) {
    MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
  }
  return swc_index;
}

int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra, int64_t curr_depth,
                                                 const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
  if (next_op_->name().find(RESHAPEINFO) == std::string::npos) {
    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
  }
  if (prev_op_->name().find(RESHAPEINFO) != std::string::npos) {
    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
  }
  const auto &reshape_input_lyt = prev_op_->GetOutputLayoutFromSWCByStrategy(prev_op_stra, prev_op_output_index_);
  MS_LOG(INFO) << next_op_->name() << "'s input layout: " << reshape_input_lyt.ToString();
  auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
  // First, try to find the zero-communication strategy.
  auto swc_index = reshape_ptr->GetSWCIndexByInputLayoutWithZeroComm(reshape_input_lyt);
  if (swc_index == -1 && curr_depth == 0) {
    const auto &next_edges = reshape_ptr->succ_edges();
    if (!next_edges.empty()) {
      auto next_edge = next_edges[0];
      if (configured_ops.find(next_edge->next_operator()) != configured_ops.end()) {
        // Here, it is sure that Reshape's previous and next operators both have configured strategies,
        // thus, it is OK that communication happens here.
        swc_index = reshape_ptr->GetSWCIndexByInputLayout(reshape_input_lyt);
      }
    }
  }
  if (swc_index == -1) {
    MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << next_op_->name();
  }
  return swc_index;
}

StrategyPtr Edge::GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index) {
  if (next_op_->name().find(RESHAPEINFO) == std::string::npos) {
    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
  }
  if (prev_op_->name().find(RESHAPEINFO) != std::string::npos) {
    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
  }
  auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
  const auto &reshape_input_lyt = reshape_ptr->GetInputLayoutBySWCIndex(swc_index);
  auto stra = prev_op_->GetStrategyFromSWCByOutputLayout(reshape_input_lyt, prev_op_output_index_);
  if (stra == nullptr) {
    MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
  }
  return stra;
}

StrategyPtr Edge::GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index) {
  if (prev_op_->name().find(RESHAPEINFO) == std::string::npos) {
    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s prev_op is not a Reshape.";
  }
  if (next_op_->name().find(RESHAPEINFO) != std::string::npos) {
    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
  }
  auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
  const auto &reshape_output_lyt = reshape_ptr->GetOutputLayoutBySWCIndex(swc_index);
  auto stra = next_op_->GetStrategyFromSWCByInputLayout(reshape_output_lyt, next_op_input_index_);
  if (stra == nullptr) {
    MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << next_op_->name();
  }
  return stra;
}

bool Edge::CheckStrategyConsistency(const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
  auto prev_stra = prev_op_->selected_strategy();
  auto next_stra = next_op_->selected_strategy();
  if (prev_stra == nullptr) {
    MS_LOG(EXCEPTION) << prev_op_->name() << "'s selected strategy is null!";
  }
  if (next_stra == nullptr) {
    MS_LOG(EXCEPTION) << next_op_->name() << "'s selected strategy is null!";
  }
  auto cost = GetCostByStrategyPair({prev_stra, next_stra});
  if ((configured_ops.find(prev_op_) == configured_ops.end() ||
       configured_ops.find(next_op_) == configured_ops.end()) &&
      (cost == nullptr || cost->communication_cost_ > 0.0)) {
    PrintStrategy(prev_op_->selected_strategy());
    PrintStrategy(next_op_->selected_strategy());
    MS_LOG(ERROR) << "There are redistribution cost occurs at edge: " << edge_name()
                  << ", consider configuring sharding strategies for two operators."
                  << " The full name of these two operators are: " << prev_op_->cnode()->fullname_with_scope()
                  << " and " << next_op_->cnode()->fullname_with_scope();
    return false;
  }
  return true;
}

void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map) {
  cost_map_ = cost_map;
  pre_op_output_.clear();

@@ -82,8 +82,17 @@ class Edge {
  Status InitEdgeCost();
  std::map<CostPtrKey, CostPtrList> GetCostMap() { return cost_map_; }
  CostPtr GetCostByStrategyPair(const CostPtrKey &);

  StrategyPtr GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPtr &);
  StrategyPtr GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPtr &);
  int64_t GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra, int64_t curr_depth,
                                             const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
  int64_t GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra, int64_t curr_depth,
                                             const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
  StrategyPtr GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index);
  StrategyPtr GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index);
  bool CheckStrategyConsistency(const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);

  void SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &);
  // For two operators u--->v, given the output tensor layout of u,
  // and the input tensor layout of v, return the redistribution cost,

@@ -166,7 +175,7 @@ class Edge {
  // operators. The resulting edges are different from those normal edges, thus this Bool variable distinguishes them.
  // If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor.
  bool is_identity_edge;
  CostPtr selected_cost_;
  CostPtr selected_cost_ = nullptr;
  // In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator
  // is parameter-involved
  int64_t is_output_parameter_involve_ = -1;  // -1: unset; 0: not parameter_involved; 1: parameter_involved

@@ -101,84 +101,157 @@ void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &
  }
}

void CheckShardingConsisitency(std::map<OperatorInfoPtr, StrategyPtr> configured_ops, const OperatorInfoPtr &curr_op,
                               const OperatorInfoPtr &another_op, const CostPtr &cost, const EdgePtr &edge) {
  if ((configured_ops.find(another_op) == configured_ops.end()) &&
      (cost == nullptr || cost->communication_cost_ != 0.0)) {
    PrintStrategy(another_op->selected_strategy());
    PrintStrategy(curr_op->selected_strategy());
    MS_LOG(EXCEPTION) << "There are redistribution cost occurs at edge: " << edge->edge_name()
                      << ", consider configuring sharding strategies for two operators."
                      << " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
                      << " and " << another_op->cnode()->fullname_with_scope();
void CheckVisitedEdgeConsistency(const EdgePtr &edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops) {
  auto prev_op = edge->prev_operator();
  auto next_op = edge->next_operator();
  if (prev_op->name().find(RESHAPEINFO) != std::string::npos) {
    const auto &reshape_output_lyt =
      next_op->GetInputLayoutFromSWCByStrategy(next_op->selected_strategy(), edge->next_op_input_index());
    auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op);
    auto consistency =
      reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
    if (!consistency) {
      MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
    }
  } else if (next_op->name().find(RESHAPEINFO) != std::string::npos) {
    const auto &reshape_input_lyt =
      prev_op->GetOutputLayoutFromSWCByStrategy(prev_op->selected_strategy(), edge->prev_op_output_index());
    auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op);
    auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
    if (!consistency) {
      MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
    }
  } else {
    auto consistency = edge->CheckStrategyConsistency(configured_ops);
    if (!consistency) {
      MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
    }
  }
}

void CheckConfiguredSuccEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops,
                                        int64_t curr_depth) {
  auto curr_op = edge->prev_operator();
  auto next_op = edge->next_operator();
  if ((curr_op->name().find(RESHAPEINFO) != std::string::npos) && curr_depth > 1) {
    auto next_op_conf_stra = configured_ops[next_op];
    const auto &reshape_output_lyt =
      next_op->GetInputLayoutFromSWCByStrategy(next_op_conf_stra, edge->next_op_input_index());
    auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
    auto consistency =
      reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
    if (!consistency) {
      MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
    }
  } else if (curr_op->name().find(RESHAPEINFO) == std::string::npos) {
    const auto &next_op_conf_stra = configured_ops[next_op];
    const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
    if ((next_op_conf_stra == nullptr) || (!next_op_conf_stra->IsEqual(next_op_stra))) {
      MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
                        << "Currently reaching " << curr_op->name() << " and " << next_op->name() << "."
                        << " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
                        << " and " << next_op->cnode()->fullname_with_scope();
    }
  }
}

void CheckConfiguredPrevEdgeConsistency(const EdgePtr edge, std::map<OperatorInfoPtr, StrategyPtr> configured_ops,
                                        int64_t curr_depth) {
  auto curr_op = edge->next_operator();
  auto prev_op = edge->prev_operator();
  if (curr_op->name().find(RESHAPEINFO) != std::string::npos && curr_depth > 1) {
    auto prev_op_conf_stra = configured_ops[prev_op];
    const auto &reshape_input_lyt =
      prev_op->GetOutputLayoutFromSWCByStrategy(prev_op_conf_stra, edge->prev_op_output_index());
    auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
    auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
    if (!consistency) {
      MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name();
    }
  } else if (curr_op->name().find(RESHAPEINFO) == std::string::npos) {
    const auto &prev_op_conf_stra = configured_ops[prev_op];
    const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
    if ((prev_op_conf_stra == nullptr) || (!prev_op_conf_stra->IsEqual(prev_op_stra))) {
      MS_LOG(ERROR) << "curr_depth: " << curr_depth;
      MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
                        << "Currently reaching " << prev_op->name() << " and " << curr_op->name() << "."
                        << " The full name of these two operators are: " << prev_op->cnode()->fullname_with_scope()
                        << " and " << curr_op->cnode()->fullname_with_scope();
    }
  }
}

void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
                    std::map<OperatorInfoPtr, StrategyPtr> configured_ops, std::map<OperatorInfoPtr, bool> *visited) {
  std::queue<std::pair<std::pair<OperatorInfoPtr, StrategyPtr>, size_t>> next_level;
  (void)next_level.emplace(std::make_pair(op, op_stra), 0);
  std::queue<std::pair<std::pair<OperatorInfoPtr, std::pair<StrategyPtr, int64_t>>, int64_t>> next_level;
  (void)next_level.emplace(std::make_pair(op, std::make_pair(op_stra, -1)), 0);
  while (!next_level.empty()) {
    auto curr_op = next_level.front().first.first;
    auto configured_stra = next_level.front().first.second;
    auto configured_stra = next_level.front().first.second.first;
    auto configured_stra_index = next_level.front().first.second.second;
    auto curr_depth = next_level.front().second;
    visited->at(curr_op) = true;
    MS_LOG(INFO) << "curr_depth: " << curr_depth;
    curr_op->SetSelectedStrategy(configured_stra, curr_depth);
    if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
      curr_op->set_swc_index(configured_stra_index, curr_depth);
    } else {
      curr_op->SetSelectedStrategy(configured_stra, curr_depth);
    }
    for (auto &edge : curr_op->succ_edges()) {
      const auto &next_op = edge->next_operator();
      if (visited->at(next_op)) {
        const auto cost = edge->GetCostByStrategyPair({curr_op->selected_strategy(), next_op->selected_strategy()});
        CheckShardingConsisitency(configured_ops, curr_op, next_op, cost, edge);
        CheckVisitedEdgeConsistency(edge, configured_ops);
        continue;
      }
      if ((curr_depth > 0) && (configured_ops.find(next_op) != configured_ops.end())) {
        const auto &next_op_conf_stra = configured_ops[next_op];
        const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
        if ((next_op_conf_stra == nullptr) || (!next_op_conf_stra->IsEqual(next_op_stra))) {
          MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
                            << "Currently reaching " << curr_op->name() << " and " << next_op->name() << "."
                            << " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
                            << " and " << next_op->cnode()->fullname_with_scope();
        }
        CheckConfiguredSuccEdgeConsistency(edge, configured_ops, curr_depth);
      }
      if (configured_ops.find(next_op) != configured_ops.end()) {
        continue;
      }
      const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
      if (next_op_stra == nullptr) {
        PrintStrategy(curr_op->selected_strategy());
        MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
      if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
        auto stra = edge->GetNextOpStrategyByReshapeSWCIndex(curr_op->swc_index());
        (void)next_level.emplace(std::make_pair(next_op, std::make_pair(stra, -1)), curr_depth + 1);
      } else if (next_op->name().find(RESHAPEINFO) != std::string::npos) {
        auto swc_index =
          edge->GetReshapeSWCIndexByPrevOpStrategy(curr_op->selected_strategy(), curr_depth, configured_ops);
        (void)next_level.emplace(std::make_pair(next_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
      } else {
        const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
        if (next_op_stra == nullptr) {
          PrintStrategy(curr_op->selected_strategy());
          MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
        }
        (void)next_level.emplace(std::make_pair(next_op, std::make_pair(next_op_stra, -1)), curr_depth + 1);
      }
      (void)next_level.emplace(std::make_pair(next_op, next_op_stra), curr_depth + 1);
    }
    for (auto &edge : curr_op->prev_edges()) {
      const auto &prev_op = edge->prev_operator();
      if (visited->at(prev_op)) {
        const auto cost = edge->GetCostByStrategyPair({prev_op->selected_strategy(), curr_op->selected_strategy()});
        CheckShardingConsisitency(configured_ops, curr_op, prev_op, cost, edge);
        CheckVisitedEdgeConsistency(edge, configured_ops);
        continue;
      }
      if ((curr_depth > 0) && (configured_ops.find(prev_op) != configured_ops.end())) {
        const auto &prev_op_conf_stra = configured_ops[prev_op];
        const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
        if ((prev_op_conf_stra == nullptr) || (!prev_op_conf_stra->IsEqual(prev_op_stra))) {
          MS_LOG(ERROR) << "curr_depth: " << curr_depth;
          MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
                            << "Currently reaching " << prev_op->name() << " and " << curr_op->name() << "."
                            << " The full name of these two operators are: " << prev_op->cnode()->fullname_with_scope()
                            << " and " << curr_op->cnode()->fullname_with_scope();
        }
        CheckConfiguredPrevEdgeConsistency(edge, configured_ops, curr_depth);
      }
      if (configured_ops.find(prev_op) != configured_ops.end()) {
        continue;
      }
      const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
      if (prev_op_stra == nullptr) {
        PrintStrategy(curr_op->selected_strategy());
        MS_LOG(EXCEPTION) << prev_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
      if (prev_op->name().find(RESHAPEINFO) != std::string::npos) {
        auto swc_index =
          edge->GetReshapeSWCIndexByNextOpStrategy(curr_op->selected_strategy(), curr_depth, configured_ops);
        (void)next_level.emplace(std::make_pair(prev_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
      } else if (curr_op->name().find(RESHAPEINFO) != std::string::npos) {
        auto prev_stra = edge->GetPrevOpStrategyByReshapeSWCIndex(curr_op->swc_index());
        (void)next_level.emplace(std::make_pair(prev_op, std::make_pair(prev_stra, -1)), curr_depth + 1);
      } else {
        const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
        if (prev_op_stra == nullptr) {
          PrintStrategy(curr_op->selected_strategy());
          MS_LOG(EXCEPTION) << prev_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
        }
        (void)next_level.emplace(std::make_pair(prev_op, std::make_pair(prev_op_stra, -1)), curr_depth + 1);
      }
      (void)next_level.emplace(std::make_pair(prev_op, prev_op_stra), curr_depth + 1);
    }
    next_level.pop();
  }

@@ -1363,6 +1363,50 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
  return SUCCESS;
}

TensorLayout OperatorInfo::GetInputLayoutFromSWCByStrategy(StrategyPtr stra, size_t input_index) {
  auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) { return swc->strategy_ptr->IsEqual(stra); };
  auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
  if (it != strategy_cost_.end()) {
    const auto &input_info = (*it)->inputs_ptr[input_index];
    return std::move(input_info.tensor_layout());
  }
  TensorLayout empty;
  return empty;
}

TensorLayout OperatorInfo::GetOutputLayoutFromSWCByStrategy(StrategyPtr stra, size_t output_index) {
  auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) { return swc->strategy_ptr->IsEqual(stra); };
  auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
  if (it != strategy_cost_.end()) {
    const auto &output_info = (*it)->outputs_ptr[output_index];
    return std::move(output_info.tensor_layout());
  }
  TensorLayout empty;
  return empty;
}

StrategyPtr OperatorInfo::GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index) {
  auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) {
    return swc->inputs_ptr[input_index].tensor_layout() == input_layout;
  };
  auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
  if (it != strategy_cost_.end()) {
    return (*it)->strategy_ptr;
  }
  return nullptr;
}

StrategyPtr OperatorInfo::GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index) {
  auto is_target = [&](std::shared_ptr<StrategyWithCost> swc) {
    return swc->outputs_ptr[output_index].tensor_layout() == output_layout;
  };
  auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
  if (it != strategy_cost_.end()) {
    return (*it)->strategy_ptr;
  }
  return nullptr;
}

// Keep at most (1.0 / epsilon) number of available strategies for each operator.
void OperatorInfo::ApproximateStrategies() {
  auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();

@@ -1683,6 +1727,12 @@ void OperatorInfo::SetSelectedStrategy(const StrategyPtr &s_strategy, size_t cur
  selected_strategy_depth_ = SizeToLong(curr_depth);
}

void OperatorInfo::set_swc_index(int64_t swc, int64_t depth) {
  MS_LOG(INFO) << "Set SWC index: " << swc << " for: " << name();
  selected_strategy_depth_ = depth;
  swc_index_ = swc;
}

CNodePtr OperatorInfo::cnode() {
  MS_EXCEPTION_IF_NULL(cnode_);
  return cnode_;

@@ -144,6 +144,14 @@ class OperatorInfo {
  void SetSelectedStrategy(const StrategyPtr &s_strategy, size_t);
  StrategyPtr selected_strategy() const { return selected_strategy_; }
  CostPtr selected_cost() const { return selected_cost_; }

  TensorLayout GetInputLayoutFromSWCByStrategy(StrategyPtr stra, size_t input_index);
  TensorLayout GetOutputLayoutFromSWCByStrategy(StrategyPtr stra, size_t output_index);
  StrategyPtr GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index);
  StrategyPtr GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index);

  void set_swc_index(int64_t, int64_t);
  int64_t swc_index() { return swc_index_; }
  // Approximate the list of available strategies
  void ApproximateStrategies();
  // Make the list of available strategies exact and re-init the related edges incident to this operator

@@ -293,6 +301,7 @@ class OperatorInfo {
 private:
  OperatorCostPtr operator_cost_;
  std::vector<TypePtr> outputs_type_;
  int64_t swc_index_ = -1;
};

Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);

@@ -431,7 +431,7 @@ std::vector<StrategyPtr> ReshapeInfo::GenerateOpStrategies(int64_t) {
  return sp_vector;
}

Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
Status ReshapeInfo::GenerateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
                                          const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs,
                                          int64_t out_index, int64_t in_index, bool is_prev_param,
                                          bool is_next_reshape) {

@@ -488,7 +488,137 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
  if (strategy_cost_.empty()) {
    return FAILED;
  }
  MS_LOG(INFO) << "Print " << name() << "'s 'strategy_cost':";
  for (auto &swc : strategy_cost_) {
    MS_LOG(INFO) << name() << "'s strategy:";
    PrintStrategy(swc->strategy_ptr);
    MS_LOG(INFO) << "The corresponding cost: " << swc->cost_list[0]->computation_cost_ << ", "
                 << swc->cost_list[0]->communication_cost_ << ", "
                 << swc->cost_list[0]->communication_without_parameter_;
    MS_LOG(INFO) << "Input layout: " << swc->inputs_ptr[0].tensor_layout().ToString();
    MS_LOG(INFO) << "Output layout: " << swc->outputs_ptr[0].tensor_layout().ToString();
  }
  return SUCCESS;
}

int64_t ReshapeInfo::GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &output_layout) {
  std::vector<std::pair<int64_t, double>> index_computation;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->outputs_ptr[0].tensor_layout() == output_layout &&
        swc->cost_list[0]->communication_without_parameter_ == 0.0) {
      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
    }
  }
  if (index_computation.empty()) {
    MS_LOG(ERROR) << "There is no available strategy for zero communication cost for reshape: " << name();
    return -1;
  }
  if (index_computation.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_computation.begin(), index_computation.end(),
    [](const std::pair<size_t, double> &a, const std::pair<size_t, double> &b) { return a.second < b.second; });
  return index_computation[0].first;
}

int64_t ReshapeInfo::GetSWCIndexByOutputLayout(const TensorLayout &output_layout) {
  std::vector<std::pair<int64_t, double>> index_computation;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
    }
  }
  if (index_computation.empty()) {
    MS_LOG(ERROR) << "There is no available strategy for zero communication cost for reshape: " << name();
    return -1;
  }
  if (index_computation.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_computation.begin(), index_computation.end(),
    [](const std::pair<size_t, double> &a, const std::pair<size_t, double> &b) { return a.second < b.second; });
  return index_computation[0].first;
}

int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &input_layout) {
  std::vector<std::pair<int64_t, double>> index_computation;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->inputs_ptr[0].tensor_layout() == input_layout &&
        swc->cost_list[0]->communication_without_parameter_ == 0.0) {
      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
    }
  }
  if (index_computation.empty()) {
    MS_LOG(ERROR) << "There is no available strategy for zero communication cost for reshape: " << name();
    return -1;
  }
  if (index_computation.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_computation.begin(), index_computation.end(),
    [](const std::pair<size_t, double> &a, const std::pair<size_t, double> &b) { return a.second < b.second; });
  return index_computation[0].first;
}

int64_t ReshapeInfo::GetSWCIndexByInputLayout(const TensorLayout &input_layout) {
  std::vector<std::pair<int64_t, double>> index_computation;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
    }
  }
  if (index_computation.empty()) {
    MS_LOG(ERROR) << "There is no available strategy for zero communication cost for reshape: " << name();
    return -1;
  }
  if (index_computation.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_computation.begin(), index_computation.end(),
    [](const std::pair<size_t, double> &a, const std::pair<size_t, double> &b) { return a.second < b.second; });
  return index_computation[0].first;
}

bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, const TensorLayout &output_layout) {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range.";
    return false;
  }
  const auto &swc = strategy_cost_[swc_index];
  return swc->outputs_ptr[0].tensor_layout() == output_layout;
}

bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const TensorLayout &input_layout) {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range.";
    return false;
  }
  const auto &swc = strategy_cost_[swc_index];
  return swc->inputs_ptr[0].tensor_layout() == input_layout;
}

TensorLayout ReshapeInfo::GetInputLayoutBySWCIndex(int64_t swc_index) {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range.";
  }
  const auto &swc = strategy_cost_[swc_index];
  return std::move(swc->inputs_ptr[0].tensor_layout());
}

TensorLayout ReshapeInfo::GetOutputLayoutBySWCIndex(int64_t swc_index) {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range.";
  }
  const auto &swc = strategy_cost_[swc_index];
  return std::move(swc->outputs_ptr[0].tensor_layout());
}
}  // namespace parallel
}  // namespace mindspore

@@ -58,7 +58,7 @@ class ReshapeInfo : public OperatorInfo {
  void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; }
  void set_pre_operator_index(int64_t pre_index) { pre_operator_index_ = pre_index; }
  void set_next_operator_index(int64_t next_index) { next_operator_index_ = next_index; }
  Status GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
  Status GenerateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
                               const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, int64_t out_index,
                               int64_t in_index, bool is_prev_param, bool is_next_reshape);
  Status GenerateStrategies(int64_t stage_id) override;

@@ -69,6 +69,16 @@ class ReshapeInfo : public OperatorInfo {
  int64_t pre_operator_index() const { return pre_operator_index_; }
  int64_t next_operator_index() const { return next_operator_index_; }

  int64_t GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &);
  int64_t GetSWCIndexByOutputLayout(const TensorLayout &);
  int64_t GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &);
  int64_t GetSWCIndexByInputLayout(const TensorLayout &);
  bool CheckStrategyConsistencyByOutputLayout(int64_t, const TensorLayout &);
  bool CheckStrategyConsistencyByInputLayout(int64_t, const TensorLayout &);

  TensorLayout GetInputLayoutBySWCIndex(int64_t);
  TensorLayout GetOutputLayoutBySWCIndex(int64_t);

 protected:
  Status CheckStrategy(const StrategyPtr &strategy) override;
  Status InferMirrorOps() override;

@@ -665,6 +665,14 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator
  (*edge_count)++;
}

void CheckAndApplyApproximation() {
  // If 'approximation' is enabled, check that the edges have effective costs.
  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
  if (approximation) {
    entire_costgraph->CheckApproximateCostGraphEdges();
  }
}

void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
  // Step 2
  MS_LOG(INFO) << "Constructing edges for cost graph begins.";

@@ -736,11 +744,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
    }
    MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
  }
  // If 'approximation' is enabled, the edges need to be checked have effective costs.
  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
  if (approximation) {
    entire_costgraph->CheckApproximateCostGraphEdges();
  }
  CheckAndApplyApproximation();

  MS_LOG(INFO) << "Constructing edges for cost graph ends.";
}

@@ -913,7 +917,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
      reshape_info->set_next_operator_index(in_index);
    }
    bool is_prev_param = pre_node->isa<Parameter>();
    if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param,
    if (reshape_info->GenerateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param,
                                            is_next_reshape) != SUCCESS) {
      MS_LOG(EXCEPTION) << "reshape generate strategy_costs failed!";
    }

@@ -0,0 +1,336 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from mindspore.common import Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from tests.ut.python.ops.test_math_ops import VirtualLoss

grad_all = C.GradOperation(get_all=True)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class NetWithLossTwoInput(nn.Cell):
    def __init__(self, network):
        super(NetWithLossTwoInput, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        predict = self.network(x, y)
        return self.loss(predict)


class GradWrapTwoInput(nn.Cell):
    def __init__(self, network):
        super(GradWrapTwoInput, self).__init__()
        self.network = network

    def construct(self, x, y):
        return grad_all(self.network)(x, y)


def compile_graph(net, device_num, x):
    context.set_auto_parallel_context(device_num=device_num, global_rank=0, parallel_mode="auto_parallel",
                                      sharding_propagation=True)
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)


def compile_graph_two_input(net, device_num, x, y):
    context.set_auto_parallel_context(device_num=device_num, global_rank=0, parallel_mode="auto_parallel",
                                      sharding_propagation=True)
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)


def test_reshape_reshape():
    """
    Feature: Sharding propagation for Reshape.
    Description: ReLU->Reshape
    Expectation: compile done without error.
    """
    device_num = 8

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.relu = P.ReLU().shard(((1, 1, 1, 1),))

        def construct(self, x):
            x = self.relu(x)
            out = self.reshape(x, (64, 28))
            out = self.reshape(out, (64, 28, 1))
            return out

    x = Tensor(np.ones([device_num * 8, 28, 1, 1]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    compile_graph(net, device_num, x)


def test_reshape_auto_1():
    """
    Feature: Sharding propagation for Reshape.
    Description: ReLU->Reshape->MatMul
    Expectation: compile done without error.
    """
    device_num = 8

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU().shard(((1, 1, 1, 1),))
            self.reshape = P.Reshape()
            self.matmul = P.MatMul().shard(((2, 1), (1, 4)))
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            x = self.relu(x)
            out = self.reshape(x, (64, 28))
            out = self.matmul(out, self.matmul_weight)
            return out

    x = Tensor(np.ones([device_num * 8, 28, 1, 1]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    compile_graph(net, device_num, x)


def test_reshape_auto_2():
    """
    Feature: Sharding propagation for Reshape.
    Description: ReLU->Reshape->MatMul->Reshape->Add
    Expectation: compile done without error.
    """
    device_num = 8

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.relu2 = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul().shard(((2, 1), (1, 4)))
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
            self.add = P.Add().shard(((2, 4), (2, 4)))
            self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1")

        def construct(self, x):
            out = self.relu(x)
            out = self.relu2(out)
            out = self.reshape(out, (64, 28))
            out = self.matmul(out, self.matmul_weight)
            out = self.reshape(out, (128, 32))
            out = self.add(out, self.add_weight)
            return out

    x = Tensor(np.ones([device_num * 8, 28, 1, 1]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    compile_graph(net, device_num, x)


def test_reshape_auto_3():
    """
    Feature: Sharding propagation for Reshape.
    Description: Mul->Add->Cast->Reshape->Cast->ReduceMean
    Expectation: compile done without error.
    """
    device_num = 8

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.gamma = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="gamma")
            self.beta = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="beta")
            self.add = P.TensorAdd().shard(((8, 1, 1), (1,)))
            self.mul = P.Mul().shard(((8, 1, 1), (1,)))
            self.mean = P.ReduceMean(keep_dims=True).shard(((8, 1),))
            self.reshape = P.Reshape()
            self.dtype1 = mstype.float16
            self.dtype2 = mstype.float32

        def construct(self, x):
            out = self.add(self.mul(x, self.gamma), self.beta)
            out = F.cast(out, self.dtype1)
            out = self.reshape(out, (-1, 1024))
            out = F.cast(out, self.dtype2)
            out = self.mean(out, -1)
            return out

    x = Tensor(np.ones([2048, 30, 1024]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    compile_graph(net, device_num, x)


def test_reshape_auto_4():
    """
    Feature: Sharding propagation for Reshape.
    Description: Mul->Add->Cast->Reshape->Cast->ReduceMean
    Expectation: compile done without error.
    """
    device_num = 8

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.gamma = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="gamma")
            self.beta = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="beta")
            self.add = P.TensorAdd().shard(((8, 1, 1), (1,)))
            self.mul = P.Mul().shard(((8, 1, 1), (1,)))
            self.mean = P.ReduceMean(keep_dims=True)
            self.reshape = P.Reshape()
            self.dtype1 = mstype.float16
            self.dtype2 = mstype.float32

        def construct(self, x):
            out = self.add(self.mul(x, self.gamma), self.beta)
            out = F.cast(out, self.dtype1)
            out = self.reshape(out, (-1, 1024))
            out = F.cast(out, self.dtype2)
            out = self.mean(out, -1)
            return out

    x = Tensor(np.ones([2048, 30, 1024]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    compile_graph(net, device_num, x)


def test_reshape_auto_5():
    """
    Feature: Sharding propagation for Reshape.
    Description: Mul->Add->Reshape->ReduceMean
    Expectation: compile done without error.
    """
    device_num = 8

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.gamma = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="gamma")
            self.beta = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="beta")
            self.add = P.TensorAdd().shard(((8, 1, 1), (1,)))
            self.mul = P.Mul()
            self.mean = P.ReduceMean(keep_dims=True).shard(((2, 4),))
            self.reshape = P.Reshape()
            self.dtype1 = mstype.float16
            self.dtype2 = mstype.float32

        def construct(self, x):
            out = self.add(self.mul(x, self.gamma), self.beta)
            out = self.reshape(out, (-1, 1024))
            out = self.mean(out, -1)
            return out

    x = Tensor(np.ones([2048, 30, 1024]), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net()))
    compile_graph(net, device_num, x)


def test_reshape_auto_6():
    """
    Feature: Sharding propagation for Reshape.
    Description: Reshape->ReLU->Mul->Reshape->Add->Mul->Reshape->Add
    Expectation: compile done without error.
    """
    device_num = 8

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul().shard(((8, 1, 1), (8, 1, 1)))
            self.reshape = P.Reshape()
            self.reduce_sum = P.ReduceSum()
            self.wide_w = Parameter(Tensor(np.ones([8, 1024*8, 64]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            mask = self.reshape(y, (8, 1024*8, 1))
            w_id = self.relu(x)
            wx = self.mul(w_id, mask)
            wide_out = self.reshape(self.reduce_sum(wx, 1), (-1, 1))
            deep_id = x + self.wide_w
            vx = self.mul(deep_id, mask)
            deep_in = self.reshape(vx, (-1, 1024*8*64))
            out = wide_out + deep_in
            return out

    x = Tensor(np.ones([8, 1024*device_num, 1]), dtype=ms.float32)
    y = Tensor(np.ones([8, 1024*device_num]), dtype=ms.float32)
    net = GradWrapTwoInput(NetWithLossTwoInput(Net()))
    compile_graph_two_input(net, device_num, x, y)


def test_reshape_depend_reshape():
    """
    Feature: Sharding propagation for Reshape.
    Description: Mul->ReLU->Reshape->Reshape->Add
    Expectation: compile with error.
    """
    device_num = 8

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape1 = P.Reshape()
            self.reshape2 = P.Reshape()
            self.relu = P.ReLU()
            self.depend = P.Depend()
            self.mul = P.Mul().shard(((2, 4), (2, 4)))
            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
            self.add = P.Add().shard(((4, 2), (4, 2)))

        def construct(self, x, y):
            out1 = self.mul(x, self.mul_weight)
            y = self.relu(y)
            out2 = self.reshape1(y, (96, 32, 4))
            out3 = self.depend(out2, out1)
            out3 = self.reshape2(out3, (128, 96))
            out = out1 + out3
            return out

    class NetWithLoss1(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss1, self).__init__()
            self.mean = P.ReduceMean(keep_dims=False)
            self.network = network

        def construct(self, x, y):
            predict = self.network(x, y)
            return self.mean(predict, ())

    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
    y = Tensor(np.ones([256, 48]), dtype=ms.float32)
    net = GradWrapTwoInput(NetWithLoss1(Net()))
    with pytest.raises(RuntimeError):
        compile_graph_two_input(net, device_num, x, y)