From a772767265f7b369b8ae81ce52d4ed863420df60 Mon Sep 17 00:00:00 2001 From: Xiaoda Zhang Date: Tue, 9 Nov 2021 15:39:02 +0800 Subject: [PATCH] support reshape in sharding propagation: 1) using 'swc index of strategy_cost_' as reshape's selected strategy; 2) when encountering reshape in BFS, select the 'swc index' with zero communication cost; 3) when encountering a reshape that is already visited, check whether there exists communication between reshape and current operator. It is OK if communication happens between two configured operators; 4) currently, two consecutive reshapes are not supported; 5) adjusting BFS structure in graph_costmodel.cc; 6) adjusting some code in step_auto_parallel.cc to avoid cyclomatic complexity. --- .../parallel/auto_parallel/edge_costmodel.cc | 117 ++++++ .../parallel/auto_parallel/edge_costmodel.h | 11 +- .../parallel/auto_parallel/graph_costmodel.cc | 163 ++++++--- .../parallel/ops_info/operator_info.cc | 50 +++ .../parallel/ops_info/operator_info.h | 9 + .../parallel/ops_info/reshape_info.cc | 132 ++++++- .../frontend/parallel/ops_info/reshape_info.h | 12 +- .../frontend/parallel/step_auto_parallel.cc | 16 +- .../test_reshape_shard_propagation.py | 336 ++++++++++++++++++ 9 files changed, 792 insertions(+), 54 deletions(-) create mode 100644 tests/ut/python/parallel/test_reshape_shard_propagation.py diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc index bedc33f81c3..a21ccff9db8 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc @@ -23,6 +23,7 @@ #include "frontend/parallel/auto_parallel/costmodel.h" #include "frontend/parallel/auto_parallel/graph_costmodel.h" #include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "frontend/parallel/ops_info/reshape_info.h" namespace mindspore { namespace parallel { @@ -394,6 +395,122 @@ StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPt return prev_op_stras[0].first; } +int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra, int64_t curr_depth, + const std::map &configured_ops) { + if (prev_op_->name().find(RESHAPEINFO) == std::string::npos) { + MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s prev_op is not a Reshape."; + } + if (next_op_->name().find(RESHAPEINFO) != std::string::npos) { + MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently."; + } + const auto &reshape_output_layout = next_op_->GetInputLayoutFromSWCByStrategy(next_op_stra, next_op_input_index_); + MS_LOG(INFO) << prev_op_->name() << "'s output layout: " << reshape_output_layout.ToString(); + auto reshape_ptr = std::dynamic_pointer_cast(prev_op_); + // First, try to find the zero communication strategy. + auto swc_index = reshape_ptr->GetSWCIndexByOutputLayoutWithZeroComm(reshape_output_layout); + if (swc_index == -1 && curr_depth == 0) { + const auto &prev_edges = reshape_ptr->prev_edges(); + if (!prev_edges.empty()) { + auto prev_edge = prev_edges[0]; + if (configured_ops.find(prev_edge->prev_operator()) != configured_ops.end()) { + // Here, it is sure that Reshape's previous and next operators are both configured strategies, + // thus, it is OK that communication happens here. 
+        swc_index = reshape_ptr->GetSWCIndexByOutputLayout(reshape_output_layout);
+      }
+    }
+  }
+  if (swc_index == -1) {
+    MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
+  }
+  return swc_index;
+}
+
+int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra, int64_t curr_depth,
+                                                 const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
+  if (next_op_->name().find(RESHAPEINFO) == std::string::npos) {
+    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
+  }
+  if (prev_op_->name().find(RESHAPEINFO) != std::string::npos) {
+    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
+  }
+  const auto &reshape_input_lyt = prev_op_->GetOutputLayoutFromSWCByStrategy(prev_op_stra, prev_op_output_index_);
+  MS_LOG(INFO) << next_op_->name() << "'s input layout: " << reshape_input_lyt.ToString();
+  auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
+  // First, try to find the zero communication strategy.
+  auto swc_index = reshape_ptr->GetSWCIndexByInputLayoutWithZeroComm(reshape_input_lyt);
+  if (swc_index == -1 && curr_depth == 0) {
+    const auto &next_edges = reshape_ptr->succ_edges();
+    if (!next_edges.empty()) {
+      auto next_edge = next_edges[0];
+      if (configured_ops.find(next_edge->next_operator()) != configured_ops.end()) {
+        // Here, it is sure that Reshape's previous and next operators are both configured strategies,
+        // thus, it is OK that communication happens here.
+        swc_index = reshape_ptr->GetSWCIndexByInputLayout(reshape_input_lyt);
+      }
+    }
+  }
+  if (swc_index == -1) {
+    MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << next_op_->name();
+  }
+  return swc_index;
+}
+
+StrategyPtr Edge::GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index) {
+  if (next_op_->name().find(RESHAPEINFO) == std::string::npos) {
+    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
+  }
+  if (prev_op_->name().find(RESHAPEINFO) != std::string::npos) {
+    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
+  }
+  auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
+  const auto &reshape_input_lyt = reshape_ptr->GetInputLayoutBySWCIndex(swc_index);
+  auto stra = prev_op_->GetStrategyFromSWCByOutputLayout(reshape_input_lyt, prev_op_output_index_);
+  if (stra == nullptr) {
+    MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
+  }
+  return stra;
+}
+
+StrategyPtr Edge::GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index) {
+  if (prev_op_->name().find(RESHAPEINFO) == std::string::npos) {
+    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s prev_op is not a Reshape.";
+  }
+  if (next_op_->name().find(RESHAPEINFO) != std::string::npos) {
+    MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
+  }
+  auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
+  const auto &reshape_output_lyt = reshape_ptr->GetOutputLayoutBySWCIndex(swc_index);
+  auto stra = next_op_->GetStrategyFromSWCByInputLayout(reshape_output_lyt, next_op_input_index_);
+  if (stra == nullptr) {
+    MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << next_op_->name();
+  }
+  return stra;
+}
+
+bool Edge::CheckStrategyConsistency(const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops) {
+  auto prev_stra = prev_op_->selected_strategy();
+  auto next_stra = next_op_->selected_strategy();
+  if (prev_stra == nullptr) {
+    MS_LOG(EXCEPTION) << prev_op_->name() << "'s selected strategy is null!";
+  }
+  if (next_stra == nullptr) {
+    MS_LOG(EXCEPTION) << next_op_->name() << "'s selected strategy is null!";
+  }
+  auto cost = GetCostByStrategyPair({prev_stra, next_stra});
+  if ((configured_ops.find(prev_op_) == configured_ops.end() ||
+       configured_ops.find(next_op_) == configured_ops.end()) &&
+      (cost == nullptr || cost->communication_cost_ > 0.0)) {
+    PrintStrategy(prev_op_->selected_strategy());
+    PrintStrategy(next_op_->selected_strategy());
+    MS_LOG(ERROR) << "Redistribution cost occurs at edge: " << edge_name()
+                  << ", consider configuring sharding strategies for two operators."
+                  << " The full name of these two operators are: " << prev_op_->cnode()->fullname_with_scope()
+                  << " and " << next_op_->cnode()->fullname_with_scope();
+    return false;
+  }
+  return true;
+}
+
 void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map) {
   cost_map_ = cost_map;
   pre_op_output_.clear();
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
index 9326df941cb..4bec3840c94 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
@@ -82,8 +82,17 @@ class Edge {
   Status InitEdgeCost();
   std::map<CostPtrKey, CostPtrList> GetCostMap() { return cost_map_; }
   CostPtr GetCostByStrategyPair(const CostPtrKey &);
+
   StrategyPtr GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPtr &);
   StrategyPtr GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPtr &);
+  int64_t GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra, int64_t curr_depth,
+                                             const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
+  int64_t GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra, int64_t curr_depth,
+                                             const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
+  StrategyPtr GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index);
+  StrategyPtr GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index);
+  bool CheckStrategyConsistency(const std::map<OperatorInfoPtr, StrategyPtr> &configured_ops);
+
   void SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &);
   // For two operators u--->v, given the output tensor layout of u,
   // and the input tensor layout of v, return the redistribution cost,
@@ -166,7 +175,7 @@ class Edge {
   // operators. The resulting edges are different from those normal edges, thus this Bool variable distinguishes them.
   // If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor.
bool is_identity_edge; - CostPtr selected_cost_; + CostPtr selected_cost_ = nullptr; // In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator // is parameter-involved int64_t is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc index 71f70ab0cd2..66180a9df0f 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc @@ -101,84 +101,157 @@ void CostGraph::StrategyPropagate(const std::map & } } -void CheckShardingConsisitency(std::map configured_ops, const OperatorInfoPtr &curr_op, - const OperatorInfoPtr &another_op, const CostPtr &cost, const EdgePtr &edge) { - if ((configured_ops.find(another_op) == configured_ops.end()) && - (cost == nullptr || cost->communication_cost_ != 0.0)) { - PrintStrategy(another_op->selected_strategy()); - PrintStrategy(curr_op->selected_strategy()); - MS_LOG(EXCEPTION) << "There are redistribution cost occurs at edge: " << edge->edge_name() - << ", consider configuring sharding strategies for two operators." - << " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope() - << " and " << another_op->cnode()->fullname_with_scope(); +void CheckVisitedEdgeConsistency(const EdgePtr &edge, std::map configured_ops) { + auto prev_op = edge->prev_operator(); + auto next_op = edge->next_operator(); + if (prev_op->name().find(RESHAPEINFO) != std::string::npos) { + const auto &reshape_output_lyt = + next_op->GetInputLayoutFromSWCByStrategy(next_op->selected_strategy(), edge->next_op_input_index()); + auto reshape_ptr = std::dynamic_pointer_cast(prev_op); + auto consistency = + reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt); + if (!consistency) { + MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name(); + } + } else if (next_op->name().find(RESHAPEINFO) != std::string::npos) { + const auto &reshape_input_lyt = + prev_op->GetOutputLayoutFromSWCByStrategy(prev_op->selected_strategy(), edge->prev_op_output_index()); + auto reshape_ptr = std::dynamic_pointer_cast(next_op); + auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt); + if (!consistency) { + MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name(); + } + } else { + auto consistency = edge->CheckStrategyConsistency(configured_ops); + if (!consistency) { + MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name(); + } + } +} + +void CheckConfiguredSuccEdgeConsistency(const EdgePtr edge, std::map configured_ops, + int64_t curr_depth) { + auto curr_op = edge->prev_operator(); + auto next_op = edge->next_operator(); + if ((curr_op->name().find(RESHAPEINFO) != std::string::npos) && curr_depth > 1) { + auto next_op_conf_stra = configured_ops[next_op]; + const auto &reshape_output_lyt = + next_op->GetInputLayoutFromSWCByStrategy(next_op_conf_stra, edge->next_op_input_index()); + auto reshape_ptr = std::dynamic_pointer_cast(curr_op); + auto consistency = + reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt); + if (!consistency) { + MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name(); + } + } else if 
(curr_op->name().find(RESHAPEINFO) == std::string::npos) { + const auto &next_op_conf_stra = configured_ops[next_op]; + const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy()); + if ((next_op_conf_stra == nullptr) || (!next_op_conf_stra->IsEqual(next_op_stra))) { + MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. " + << "Currently reaching " << curr_op->name() << " and " << next_op->name() << "." + << " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope() + << " and " << next_op->cnode()->fullname_with_scope(); + } + } +} + +void CheckConfiguredPrevEdgeConsistency(const EdgePtr edge, std::map configured_ops, + int64_t curr_depth) { + auto curr_op = edge->next_operator(); + auto prev_op = edge->prev_operator(); + if (curr_op->name().find(RESHAPEINFO) != std::string::npos && curr_depth > 1) { + auto prev_op_conf_stra = configured_ops[prev_op]; + const auto &reshape_input_lyt = + prev_op->GetOutputLayoutFromSWCByStrategy(prev_op_conf_stra, edge->prev_op_output_index()); + auto reshape_ptr = std::dynamic_pointer_cast(curr_op); + auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt); + if (!consistency) { + MS_LOG(EXCEPTION) << "Inconsistency occurred at edge: " << edge->edge_name(); + } + } else if (curr_op->name().find(RESHAPEINFO) == std::string::npos) { + const auto &prev_op_conf_stra = configured_ops[prev_op]; + const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy()); + if ((prev_op_conf_stra == nullptr) || (!prev_op_conf_stra->IsEqual(prev_op_stra))) { + MS_LOG(ERROR) << "curr_depth: " << curr_depth; + MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. " + << "Currently reaching " << prev_op->name() << " and " << curr_op->name() << "." 
+ << " The full name of these two operators are: " << prev_op->cnode()->fullname_with_scope() + << " and " << curr_op->cnode()->fullname_with_scope(); + } } } void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra, std::map configured_ops, std::map *visited) { - std::queue, size_t>> next_level; - (void)next_level.emplace(std::make_pair(op, op_stra), 0); + std::queue>, int64_t>> next_level; + (void)next_level.emplace(std::make_pair(op, std::make_pair(op_stra, -1)), 0); while (!next_level.empty()) { auto curr_op = next_level.front().first.first; - auto configured_stra = next_level.front().first.second; + auto configured_stra = next_level.front().first.second.first; + auto configured_stra_index = next_level.front().first.second.second; auto curr_depth = next_level.front().second; visited->at(curr_op) = true; MS_LOG(INFO) << "curr_depth: " << curr_depth; - curr_op->SetSelectedStrategy(configured_stra, curr_depth); + if (curr_op->name().find(RESHAPEINFO) != std::string::npos) { + curr_op->set_swc_index(configured_stra_index, curr_depth); + } else { + curr_op->SetSelectedStrategy(configured_stra, curr_depth); + } for (auto &edge : curr_op->succ_edges()) { const auto &next_op = edge->next_operator(); if (visited->at(next_op)) { - const auto cost = edge->GetCostByStrategyPair({curr_op->selected_strategy(), next_op->selected_strategy()}); - CheckShardingConsisitency(configured_ops, curr_op, next_op, cost, edge); + CheckVisitedEdgeConsistency(edge, configured_ops); continue; } if ((curr_depth > 0) && (configured_ops.find(next_op) != configured_ops.end())) { - const auto &next_op_conf_stra = configured_ops[next_op]; - const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy()); - if ((next_op_conf_stra == nullptr) || (!next_op_conf_stra->IsEqual(next_op_stra))) { - MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. " - << "Currently reaching " << curr_op->name() << " and " << next_op->name() << "." 
- << " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope() - << " and " << next_op->cnode()->fullname_with_scope(); - } + CheckConfiguredSuccEdgeConsistency(edge, configured_ops, curr_depth); } if (configured_ops.find(next_op) != configured_ops.end()) { continue; } - const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy()); - if (next_op_stra == nullptr) { - PrintStrategy(curr_op->selected_strategy()); - MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name(); + if (curr_op->name().find(RESHAPEINFO) != std::string::npos) { + auto stra = edge->GetNextOpStrategyByReshapeSWCIndex(curr_op->swc_index()); + (void)next_level.emplace(std::make_pair(next_op, std::make_pair(stra, -1)), curr_depth + 1); + } else if (next_op->name().find(RESHAPEINFO) != std::string::npos) { + auto swc_index = + edge->GetReshapeSWCIndexByPrevOpStrategy(curr_op->selected_strategy(), curr_depth, configured_ops); + (void)next_level.emplace(std::make_pair(next_op, std::make_pair(nullptr, swc_index)), curr_depth + 1); + } else { + const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy()); + if (next_op_stra == nullptr) { + PrintStrategy(curr_op->selected_strategy()); + MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name(); + } + (void)next_level.emplace(std::make_pair(next_op, std::make_pair(next_op_stra, -1)), curr_depth + 1); } - (void)next_level.emplace(std::make_pair(next_op, next_op_stra), curr_depth + 1); } for (auto &edge : curr_op->prev_edges()) { const auto &prev_op = edge->prev_operator(); if (visited->at(prev_op)) { - const auto cost = edge->GetCostByStrategyPair({prev_op->selected_strategy(), curr_op->selected_strategy()}); - CheckShardingConsisitency(configured_ops, curr_op, prev_op, cost, edge); + CheckVisitedEdgeConsistency(edge, configured_ops); continue; } if ((curr_depth > 0) && (configured_ops.find(prev_op) != configured_ops.end())) { - const auto &prev_op_conf_stra = configured_ops[prev_op]; - const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy()); - if ((prev_op_conf_stra == nullptr) || (!prev_op_conf_stra->IsEqual(prev_op_stra))) { - MS_LOG(ERROR) << "curr_depth: " << curr_depth; - MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. " - << "Currently reaching " << prev_op->name() << " and " << curr_op->name() << "." 
- << " The full name of these two operators are: " << prev_op->cnode()->fullname_with_scope() - << " and " << curr_op->cnode()->fullname_with_scope(); - } + CheckConfiguredPrevEdgeConsistency(edge, configured_ops, curr_depth); } if (configured_ops.find(prev_op) != configured_ops.end()) { continue; } - const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy()); - if (prev_op_stra == nullptr) { - PrintStrategy(curr_op->selected_strategy()); - MS_LOG(EXCEPTION) << prev_op->name() << "'s strategy is null in the edge: " << edge->edge_name(); + if (prev_op->name().find(RESHAPEINFO) != std::string::npos) { + auto swc_index = + edge->GetReshapeSWCIndexByNextOpStrategy(curr_op->selected_strategy(), curr_depth, configured_ops); + (void)next_level.emplace(std::make_pair(prev_op, std::make_pair(nullptr, swc_index)), curr_depth + 1); + } else if (curr_op->name().find(RESHAPEINFO) != std::string::npos) { + auto prev_stra = edge->GetPrevOpStrategyByReshapeSWCIndex(curr_op->swc_index()); + (void)next_level.emplace(std::make_pair(prev_op, std::make_pair(prev_stra, -1)), curr_depth + 1); + } else { + const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy()); + if (prev_op_stra == nullptr) { + PrintStrategy(curr_op->selected_strategy()); + MS_LOG(EXCEPTION) << prev_op->name() << "'s strategy is null in the edge: " << edge->edge_name(); + } + (void)next_level.emplace(std::make_pair(prev_op, std::make_pair(prev_op_stra, -1)), curr_depth + 1); } - (void)next_level.emplace(std::make_pair(prev_op, prev_op_stra), curr_depth + 1); } next_level.pop(); } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index e4496ce6300..70b7447a850 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -1363,6 +1363,50 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { return SUCCESS; } +TensorLayout OperatorInfo::GetInputLayoutFromSWCByStrategy(StrategyPtr stra, size_t input_index) { + auto is_target = [&](std::shared_ptr swc) { return swc->strategy_ptr->IsEqual(stra); }; + auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target); + if (it != strategy_cost_.end()) { + const auto &input_info = (*it)->inputs_ptr[input_index]; + return std::move(input_info.tensor_layout()); + } + TensorLayout empty; + return empty; +} + +TensorLayout OperatorInfo::GetOutputLayoutFromSWCByStrategy(StrategyPtr stra, size_t output_index) { + auto is_target = [&](std::shared_ptr swc) { return swc->strategy_ptr->IsEqual(stra); }; + auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target); + if (it != strategy_cost_.end()) { + const auto &output_info = (*it)->outputs_ptr[output_index]; + return std::move(output_info.tensor_layout()); + } + TensorLayout empty; + return empty; +} + +StrategyPtr OperatorInfo::GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index) { + auto is_target = [&](std::shared_ptr swc) { + return swc->inputs_ptr[input_index].tensor_layout() == input_layout; + }; + auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target); + if (it != strategy_cost_.end()) { + return (*it)->strategy_ptr; + } + return nullptr; +} + +StrategyPtr OperatorInfo::GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index) { + auto is_target = 
[&](std::shared_ptr swc) { + return swc->outputs_ptr[output_index].tensor_layout() == output_layout; + }; + auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target); + if (it != strategy_cost_.end()) { + return (*it)->strategy_ptr; + } + return nullptr; +} + // Keep at most (1.0 / epsilon) number of available strategies for each operator. void OperatorInfo::ApproximateStrategies() { auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi(); @@ -1683,6 +1727,12 @@ void OperatorInfo::SetSelectedStrategy(const StrategyPtr &s_strategy, size_t cur selected_strategy_depth_ = SizeToLong(curr_depth); } +void OperatorInfo::set_swc_index(int64_t swc, int64_t depth) { + MS_LOG(INFO) << "Set SWC index: " << swc << " for: " << name(); + selected_strategy_depth_ = depth; + swc_index_ = swc; +} + CNodePtr OperatorInfo::cnode() { MS_EXCEPTION_IF_NULL(cnode_); return cnode_; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index 8513fbbb017..5e8551aa696 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -144,6 +144,14 @@ class OperatorInfo { void SetSelectedStrategy(const StrategyPtr &s_strategy, size_t); StrategyPtr selected_strategy() const { return selected_strategy_; } CostPtr selected_cost() const { return selected_cost_; } + + TensorLayout GetInputLayoutFromSWCByStrategy(StrategyPtr stra, size_t input_index); + TensorLayout GetOutputLayoutFromSWCByStrategy(StrategyPtr stra, size_t output_index); + StrategyPtr GetStrategyFromSWCByInputLayout(TensorLayout input_layout, size_t input_index); + StrategyPtr GetStrategyFromSWCByOutputLayout(TensorLayout output_layout, size_t output_index); + + void set_swc_index(int64_t, int64_t); + int64_t swc_index() { return swc_index_; } // Approximate the list of available strategies void ApproximateStrategies(); // Make the list of available strategies exact and re-init the related edges incident to this operator @@ -293,6 +301,7 @@ class OperatorInfo { private: OperatorCostPtr operator_cost_; std::vector outputs_type_; + int64_t swc_index_ = -1; }; Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc index 3841ee8c709..025feeb210f 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc @@ -431,7 +431,7 @@ std::vector ReshapeInfo::GenerateOpStrategies(int64_t) { return sp_vector; } -Status ReshapeInfo::GenetateStrategyCosts(const std::vector> &pre_stra_costs, +Status ReshapeInfo::GenerateStrategyCosts(const std::vector> &pre_stra_costs, const std::vector> &next_stra_costs, int64_t out_index, int64_t in_index, bool is_prev_param, bool is_next_reshape) { @@ -488,7 +488,137 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vectorstrategy_ptr); + MS_LOG(INFO) << "The corresponding cost: " << swc->cost_list[0]->computation_cost_ << ", " + << swc->cost_list[0]->communication_cost_ << ", " + << swc->cost_list[0]->communication_without_parameter_; + MS_LOG(INFO) << "Input layout: " << swc->inputs_ptr[0].tensor_layout().ToString(); + MS_LOG(INFO) << "Output layout: " << swc->outputs_ptr[0].tensor_layout().ToString(); + } return SUCCESS; } + +int64_t ReshapeInfo::GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout 
&output_layout) {
+  std::vector<std::pair<int64_t, double>> index_computation;
+  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
+    const auto &swc = strategy_cost_[i];
+    if (swc->outputs_ptr[0].tensor_layout() == output_layout &&
+        swc->cost_list[0]->communication_without_parameter_ == 0.0) {
+      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
+    }
+  }
+  if (index_computation.empty()) {
+    MS_LOG(ERROR) << "There is no available strategy with zero communication cost for reshape: " << name();
+    return -1;
+  }
+  if (index_computation.size() > 1) {
+    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
+  }
+  std::sort(
+    index_computation.begin(), index_computation.end(),
+    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
+  return index_computation[0].first;
+}
+
+int64_t ReshapeInfo::GetSWCIndexByOutputLayout(const TensorLayout &output_layout) {
+  std::vector<std::pair<int64_t, double>> index_computation;
+  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
+    const auto &swc = strategy_cost_[i];
+    if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
+      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
+    }
+  }
+  if (index_computation.empty()) {
+    MS_LOG(ERROR) << "There is no available strategy for reshape: " << name();
+    return -1;
+  }
+  if (index_computation.size() > 1) {
+    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
+  }
+  std::sort(
+    index_computation.begin(), index_computation.end(),
+    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
+  return index_computation[0].first;
+}
+
+int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &input_layout) {
+  std::vector<std::pair<int64_t, double>> index_computation;
+  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
+    const auto &swc = strategy_cost_[i];
+    if (swc->inputs_ptr[0].tensor_layout() == input_layout &&
+        swc->cost_list[0]->communication_without_parameter_ == 0.0) {
+      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
+    }
+  }
+  if (index_computation.empty()) {
+    MS_LOG(ERROR) << "There is no available strategy with zero communication cost for reshape: " << name();
+    return -1;
+  }
+  if (index_computation.size() > 1) {
+    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
+  }
+  std::sort(
+    index_computation.begin(), index_computation.end(),
+    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
+  return index_computation[0].first;
+}
+
+int64_t ReshapeInfo::GetSWCIndexByInputLayout(const TensorLayout &input_layout) {
+  std::vector<std::pair<int64_t, double>> index_computation;
+  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
+    const auto &swc = strategy_cost_[i];
+    if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
+      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
+    }
+  }
+  if (index_computation.empty()) {
+    MS_LOG(ERROR) << "There is no available strategy for reshape: " << name();
+    return -1;
+  }
+  if (index_computation.size() > 1) {
+    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
+  }
+  std::sort(
+    index_computation.begin(), index_computation.end(),
+    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
+  return index_computation[0].first;
+}
+
+bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, const 
TensorLayout &output_layout) { + if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) { + MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range."; + return false; + } + const auto &swc = strategy_cost_[swc_index]; + return swc->outputs_ptr[0].tensor_layout() == output_layout; +} + +bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const TensorLayout &input_layout) { + if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) { + MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range."; + return false; + } + const auto &swc = strategy_cost_[swc_index]; + return swc->inputs_ptr[0].tensor_layout() == input_layout; +} + +TensorLayout ReshapeInfo::GetInputLayoutBySWCIndex(int64_t swc_index) { + if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) { + MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range."; + } + const auto &swc = strategy_cost_[swc_index]; + return std::move(swc->inputs_ptr[0].tensor_layout()); +} + +TensorLayout ReshapeInfo::GetOutputLayoutBySWCIndex(int64_t swc_index) { + if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) { + MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range."; + } + const auto &swc = strategy_cost_[swc_index]; + return std::move(swc->outputs_ptr[0].tensor_layout()); +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h index 4d5498dcaf7..4207f5aef62 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h @@ -58,7 +58,7 @@ class ReshapeInfo : public OperatorInfo { void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; } void set_pre_operator_index(int64_t pre_index) { pre_operator_index_ = pre_index; } void set_next_operator_index(int64_t next_index) { next_operator_index_ = next_index; } - Status GenetateStrategyCosts(const std::vector> &pre_stra_costs, + Status GenerateStrategyCosts(const std::vector> &pre_stra_costs, const std::vector> &next_stra_costs, int64_t out_index, int64_t in_index, bool is_prev_param, bool is_next_reshape); Status GenerateStrategies(int64_t stage_id) override; @@ -69,6 +69,16 @@ class ReshapeInfo : public OperatorInfo { int64_t pre_operator_index() const { return pre_operator_index_; } int64_t next_operator_index() const { return next_operator_index_; } + int64_t GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &); + int64_t GetSWCIndexByOutputLayout(const TensorLayout &); + int64_t GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &); + int64_t GetSWCIndexByInputLayout(const TensorLayout &); + bool CheckStrategyConsistencyByOutputLayout(int64_t, const TensorLayout &); + bool CheckStrategyConsistencyByInputLayout(int64_t, const TensorLayout &); + + TensorLayout GetInputLayoutBySWCIndex(int64_t); + TensorLayout GetOutputLayoutBySWCIndex(int64_t); + protected: Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 4e26f6dd132..e6387a9f662 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -665,6 +665,14 @@ void CreateEdgeBetweenTwoOps(const 
OperatorInfoPtr &prev_op_info, const Operator (*edge_count)++; } +void CheckAndApplyApproximation() { + // If 'approximation' is enabled, the edges need to be checked have effective costs. + auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi(); + if (approximation) { + entire_costgraph->CheckApproximateCostGraphEdges(); + } +} + void ConstructCostGraphEdges(const std::vector &all_nodes) { // Step 2 MS_LOG(INFO) << "Constructing edges for cost graph begins."; @@ -736,11 +744,7 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) { } MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name(); } - // If 'approximation' is enabled, the edges need to be checked have effective costs. - auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi(); - if (approximation) { - entire_costgraph->CheckApproximateCostGraphEdges(); - } + CheckAndApplyApproximation(); MS_LOG(INFO) << "Constructing edges for cost graph ends."; } @@ -913,7 +917,7 @@ void ReshapeCostCompute(const std::vector &all_nodes) { reshape_info->set_next_operator_index(in_index); } bool is_prev_param = pre_node->isa(); - if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param, + if (reshape_info->GenerateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param, is_next_reshape) != SUCCESS) { MS_LOG(EXCEPTION) << "reshape generate strategy_costs failed!"; } diff --git a/tests/ut/python/parallel/test_reshape_shard_propagation.py b/tests/ut/python/parallel/test_reshape_shard_propagation.py new file mode 100644 index 00000000000..a6d90407be3 --- /dev/null +++ b/tests/ut/python/parallel/test_reshape_shard_propagation.py @@ -0,0 +1,336 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import pytest + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor, context +from mindspore.common import dtype as mstype +from mindspore.common import Parameter +from mindspore.common.api import _cell_graph_executor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from tests.ut.python.ops.test_math_ops import VirtualLoss + +grad_all = C.GradOperation(get_all=True) + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x): + predict = self.network(x) + return self.loss(predict) + + +class GradWrap(nn.Cell): + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + + def construct(self, x): + return grad_all(self.network)(x) + +class NetWithLossTwoInput(nn.Cell): + def __init__(self, network): + super(NetWithLossTwoInput, self).__init__() + self.loss = VirtualLoss() + self.network = network + + def construct(self, x, y): + predict = self.network(x, y) + return self.loss(predict) + +class GradWrapTwoInput(nn.Cell): + def __init__(self, network): + super(GradWrapTwoInput, self).__init__() + self.network = network + + def construct(self, x, y): + return grad_all(self.network)(x, y) + + +def compile_graph(net, device_num, x): + context.set_auto_parallel_context(device_num=device_num, global_rank=0, parallel_mode="auto_parallel", + sharding_propagation=True) + net.set_auto_parallel() + net.set_train() + _cell_graph_executor.compile(net, x) + + +def compile_graph_two_input(net, device_num, x, y): + context.set_auto_parallel_context(device_num=device_num, global_rank=0, parallel_mode="auto_parallel", + sharding_propagation=True) + net.set_auto_parallel() + net.set_train() + _cell_graph_executor.compile(net, x, y) + + +def test_reshape_reshape(): + """ + Feature: Sharding propagation for Reshape. + Description: ReLU->Reshape + Expectation: compile done without error. + """ + device_num = 8 + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.reshape = P.Reshape() + self.relu = P.ReLU().shard(((1, 1, 1, 1),)) + + def construct(self, x): + x = self.relu(x) + out = self.reshape(x, (64, 28)) + out = self.reshape(out, (64, 28, 1)) + return out + + x = Tensor(np.ones([device_num * 8, 28, 1, 1]), dtype=ms.float32) + net = GradWrap(NetWithLoss(Net())) + compile_graph(net, device_num, x) + + +def test_reshape_auto_1(): + """ + Feature: Sharding propagation for Reshape. + Description: ReLU->Reshape->MatMul + Expectation: compile done without error. + """ + device_num = 8 + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.relu = P.ReLU().shard(((1, 1, 1, 1),)) + self.reshape = P.Reshape() + self.matmul = P.MatMul().shard(((2, 1), (1, 4))) + self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight") + + def construct(self, x): + x = self.relu(x) + out = self.reshape(x, (64, 28)) + out = self.matmul(out, self.matmul_weight) + return out + + x = Tensor(np.ones([device_num * 8, 28, 1, 1]), dtype=ms.float32) + net = GradWrap(NetWithLoss(Net())) + compile_graph(net, device_num, x) + + +def test_reshape_auto_2(): + """ + Feature: Sharding propagation for Reshape. + Description: ReLU->Reshape->MatMul->Reshape->Add + Expectation: compile done without error. 
+ """ + device_num = 8 + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.relu = P.ReLU() + self.relu2 = P.ReLU() + self.reshape = P.Reshape() + self.matmul = P.MatMul().shard(((2, 1), (1, 4))) + self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight") + self.add = P.Add().shard(((2, 4), (2, 4))) + self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1") + + def construct(self, x): + out = self.relu(x) + out = self.relu2(out) + out = self.reshape(out, (64, 28)) + out = self.matmul(out, self.matmul_weight) + out = self.reshape(out, (128, 32)) + out = self.add(out, self.add_weight) + return out + + x = Tensor(np.ones([device_num * 8, 28, 1, 1]), dtype=ms.float32) + net = GradWrap(NetWithLoss(Net())) + compile_graph(net, device_num, x) + + +def test_reshape_auto_3(): + """ + Feature: Sharding propagation for Reshape. + Description: Mul->Add->Cast->Reshape->Cast->ReduceMean + Expectation: compile done without error. + """ + device_num = 8 + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.gamma = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="gamma") + self.beta = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="beta") + self.add = P.TensorAdd().shard(((8, 1, 1), (1,))) + self.mul = P.Mul().shard(((8, 1, 1), (1,))) + self.mean = P.ReduceMean(keep_dims=True).shard(((8, 1),)) + self.reshape = P.Reshape() + self.dtype1 = mstype.float16 + self.dtype2 = mstype.float32 + + def construct(self, x): + out = self.add(self.mul(x, self.gamma), self.beta) + out = F.cast(out, self.dtype1) + out = self.reshape(out, (-1, 1024)) + out = F.cast(out, self.dtype2) + out = self.mean(out, -1) + return out + + x = Tensor(np.ones([2048, 30, 1024]), dtype=ms.float32) + net = GradWrap(NetWithLoss(Net())) + compile_graph(net, device_num, x) + + +def test_reshape_auto_4(): + """ + Feature: Sharding propagation for Reshape. + Description: Mul->Add->Cast->Reshape->Cast->ReduceMean + Expectation: compile done without error. + """ + device_num = 8 + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.gamma = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="gamma") + self.beta = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="beta") + self.add = P.TensorAdd().shard(((8, 1, 1), (1,))) + self.mul = P.Mul().shard(((8, 1, 1), (1,))) + self.mean = P.ReduceMean(keep_dims=True) + self.reshape = P.Reshape() + self.dtype1 = mstype.float16 + self.dtype2 = mstype.float32 + + def construct(self, x): + out = self.add(self.mul(x, self.gamma), self.beta) + out = F.cast(out, self.dtype1) + out = self.reshape(out, (-1, 1024)) + out = F.cast(out, self.dtype2) + out = self.mean(out, -1) + return out + + x = Tensor(np.ones([2048, 30, 1024]), dtype=ms.float32) + net = GradWrap(NetWithLoss(Net())) + compile_graph(net, device_num, x) + + +def test_reshape_auto_5(): + """ + Feature: Sharding propagation for Reshape. + Description: Mul->Add->Cast->Reshape->Cast->ReduceMean + Expectation: compile done without error. 
+ """ + device_num = 8 + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.gamma = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="gamma") + self.beta = Parameter(Tensor(np.ones([1024]), dtype=ms.float32), name="beta") + self.add = P.TensorAdd().shard(((8, 1, 1), (1,))) + self.mul = P.Mul() + self.mean = P.ReduceMean(keep_dims=True).shard(((2, 4),)) + self.reshape = P.Reshape() + self.dtype1 = mstype.float16 + self.dtype2 = mstype.float32 + + def construct(self, x): + out = self.add(self.mul(x, self.gamma), self.beta) + out = self.reshape(out, (-1, 1024)) + out = self.mean(out, -1) + return out + + x = Tensor(np.ones([2048, 30, 1024]), dtype=ms.float32) + net = GradWrap(NetWithLoss(Net())) + compile_graph(net, device_num, x) + + +def test_reshape_auto_6(): + """ + Feature: Sharding propagation for Reshape. + Description: Reshape->ReLU->Mul->Reshape->Add->Mul->Reshape->Add + Expectation: compile done without error. + """ + device_num = 8 + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.relu = P.ReLU() + self.mul = P.Mul().shard(((8, 1, 1), (8, 1, 1))) + self.reshape = P.Reshape() + self.reduce_sum = P.ReduceSum() + self.wide_w = Parameter(Tensor(np.ones([8, 1024*8, 64]), dtype=ms.float32), name="weight") + + def construct(self, x, y): + mask = self.reshape(y, (8, 1024*8, 1)) + w_id = self.relu(x) + wx = self.mul(w_id, mask) + wide_out = self.reshape(self.reduce_sum(wx, 1), (-1, 1)) + deep_id = x + self.wide_w + vx = self.mul(deep_id, mask) + deep_in = self.reshape(vx, (-1, 1024*8*64)) + out = wide_out + deep_in + return out + + x = Tensor(np.ones([8, 1024*device_num, 1]), dtype=ms.float32) + y = Tensor(np.ones([8, 1024*device_num]), dtype=ms.float32) + net = GradWrapTwoInput(NetWithLossTwoInput(Net())) + compile_graph_two_input(net, device_num, x, y) + + +def test_reshape_depend_reshape(): + """ + Feature: Sharding propagation for Reshape. + Description: Mul->ReLU->Reshape->Reshape->Add + Expectation: compile with error. + """ + device_num = 8 + class Net(nn.Cell): + def __init__(self): + super().__init__() + self.reshape1 = P.Reshape() + self.reshape2 = P.Reshape() + self.relu = P.ReLU() + self.depend = P.Depend() + self.mul = P.Mul().shard(((2, 4), (2, 4))) + self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight") + self.add = P.Add().shard(((4, 2), (4, 2))) + + def construct(self, x, y): + out1 = self.mul(x, self.mul_weight) + y = self.relu(y) + out2 = self.reshape1(y, (96, 32, 4)) + out3 = self.depend(out2, out1) + out3 = self.reshape2(out3, (128, 96)) + out = out1 + out3 + return out + + class NetWithLoss1(nn.Cell): + def __init__(self, network): + super(NetWithLoss1, self).__init__() + self.mean = P.ReduceMean(keep_dims=False) + self.network = network + + def construct(self, x, y): + predict = self.network(x, y) + return self.mean(predict, ()) + + x = Tensor(np.ones([128, 96]), dtype=ms.float32) + y = Tensor(np.ones([256, 48]), dtype=ms.float32) + net = GradWrapTwoInput(NetWithLoss1(Net())) + with pytest.raises(RuntimeError): + compile_graph_two_input(net, device_num, x, y)