forked from mindspore-Ecosystem/mindspore
Change the order of successor edges returned by GetAliveSuccEdges() so that Triangle and Star Elimination can be merged into a particular node; add some check information
This commit is contained in:
parent
d9c74e0acd
commit
9f4b8a3cd1
|
@ -211,13 +211,14 @@ struct ContractEliminationDecision : public Decision {
|
|||
*/
|
||||
struct TriangleEliminationDecision : public Decision {
|
||||
TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
|
||||
StrategyPtr left_stra, CostPtr l_node_cost)
|
||||
StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra)
|
||||
: eliminated_op_strategy_(std::move(elimi_stra)),
|
||||
eliminated_op_cost_(std::move(elimi_op_cost)),
|
||||
left_edge_cost_(std::move(l_edge_cost)),
|
||||
right_edge_cost_(std::move(r_edge_cost)),
|
||||
left_node_strategy_(std::move(left_stra)),
|
||||
left_node_cost_(std::move(l_node_cost)) {
|
||||
left_node_cost_(std::move(l_node_cost)),
|
||||
right_node_strategy_(std::move(right_stra)) {
|
||||
type_ = DecisionType::TRIANGLE_ELIMINATION;
|
||||
}
|
||||
|
||||
|
@ -227,6 +228,7 @@ struct TriangleEliminationDecision : public Decision {
|
|||
CostPtr right_edge_cost_;
|
||||
StrategyPtr left_node_strategy_;
|
||||
CostPtr left_node_cost_;
|
||||
StrategyPtr right_node_strategy_;
|
||||
MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
|
||||
};
|
||||
|
||||
|
|
|
@ -85,7 +85,9 @@ Status GetStrategy(const CostGraphPtr &graph) {
|
|||
right_edge = tmp;
|
||||
}
|
||||
auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
|
||||
auto elimi = std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge);
|
||||
auto right_node = l_r_edge->next_operator();
|
||||
auto elimi =
|
||||
std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
|
||||
eliminations.emplace_back(std::move(elimi));
|
||||
}
|
||||
auto star_center = graph->CheckStarElimination();
|
||||
|
@ -181,6 +183,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
|
|||
auto left_edge = elimination->left_edge_;
|
||||
auto eliminated_node = elimination->eliminated_node_;
|
||||
auto right_edge = elimination->right_edge_;
|
||||
auto right_node = elimination->right_node_;
|
||||
auto decision = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>();
|
||||
|
||||
eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
|
||||
|
@ -188,6 +191,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
|
|||
right_edge->set_selected_cost(decision->right_edge_cost_);
|
||||
// Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
|
||||
left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
|
||||
right_node->CheckSelectedStrategy(decision->right_node_strategy_);
|
||||
MS_LOG(INFO) << "Recover triangleElimination succeeded.";
|
||||
} else if ((*rit)->isa<StarElimination>()) {
|
||||
auto elimination = (*rit)->cast<StarEliminationPtr>();
|
||||
|
@ -206,6 +210,9 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
|
|||
MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]);
|
||||
// Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy.
|
||||
succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
|
||||
for (size_t k = 1; k < succ_nodes.size(); ++k) {
|
||||
succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]);
|
||||
}
|
||||
MS_LOG(INFO) << "Recover starElimination succeeded.";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unknown Elimination type.";
|
||||
|
|
|
@ -102,17 +102,20 @@ struct ContractElimination : public Elimination {
|
|||
|
||||
// Triangle Elimination
|
||||
struct TriangleElimination : public Elimination {
|
||||
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge)
|
||||
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
|
||||
OperatorInfoPtr r_node)
|
||||
: Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
|
||||
eliminated_node_(std::move(elim_node)),
|
||||
left_edge_(std::move(l_edge)),
|
||||
left_node_(std::move(l_node)),
|
||||
right_edge_(std::move(r_edge)) {}
|
||||
right_edge_(std::move(r_edge)),
|
||||
right_node_(std::move(r_node)) {}
|
||||
|
||||
OperatorInfoPtr eliminated_node_;
|
||||
EdgePtr left_edge_;
|
||||
OperatorInfoPtr left_node_;
|
||||
EdgePtr right_edge_;
|
||||
OperatorInfoPtr right_node_;
|
||||
MS_DECLARE_PARENT(TriangleElimination, Elimination);
|
||||
};
|
||||
|
||||
|
|
|
@ -1111,8 +1111,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
|
|||
elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
|
||||
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
|
||||
|
||||
auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost,
|
||||
right_edge_cost, left_op_stra, left_node_cost);
|
||||
auto decision = std::make_shared<TriangleEliminationDecision>(
|
||||
elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra);
|
||||
auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
|
||||
new_cost->communication_without_parameter_ = new_commu_without;
|
||||
new_cost->communication_with_partial_para_ =
|
||||
|
|
|
@ -546,10 +546,14 @@ std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAliveSuccEdges() {
|
|||
for (auto &edge : succ_edges_) {
|
||||
if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) {
|
||||
ret.push_back(edge);
|
||||
} else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) {
|
||||
// CAST is ordered in front of L2NORMALIZE
|
||||
ret.push_back(edge);
|
||||
}
|
||||
}
|
||||
for (auto &edge : succ_edges_) {
|
||||
if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos)) {
|
||||
if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) &&
|
||||
(edge->next_operator()->name().find(CAST) == std::string::npos)) {
|
||||
ret.push_back(edge);
|
||||
}
|
||||
}
|
||||
|
@ -1279,10 +1283,18 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra
|
|||
CheckGlobalDeviceManager();
|
||||
auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size();
|
||||
if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) {
|
||||
cost->computation_cost_ -= 1.0;
|
||||
cost->communication_cost_ -= 1.0;
|
||||
cost->communication_with_partial_para_ -= 1.0;
|
||||
cost->communication_without_parameter_ -= 1.0;
|
||||
if (cost->computation_cost_ > 1.0) {
|
||||
cost->computation_cost_ -= 1.0;
|
||||
}
|
||||
if (cost->communication_cost_ > 1.0) {
|
||||
cost->communication_cost_ -= 1.0;
|
||||
}
|
||||
if (cost->communication_with_partial_para_ > 1.0) {
|
||||
cost->communication_with_partial_para_ -= 1.0;
|
||||
}
|
||||
if (cost->communication_without_parameter_ > 1.0) {
|
||||
cost->communication_without_parameter_ -= 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1290,5 +1302,15 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra
|
|||
double OperatorInfo::GetForwardMemoryCostFromCNode() {
|
||||
return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
|
||||
}
|
||||
|
||||
void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) {
|
||||
MS_EXCEPTION_IF_NULL(s_strategy);
|
||||
if (!s_strategy->IsEqual(selected_strategy_)) {
|
||||
MS_LOG(INFO) << name() << "'s strategy may cause suboptimal, the determined strategy:";
|
||||
PrintStrategy(selected_strategy_);
|
||||
MS_LOG(INFO) << "The minimal strategy:";
|
||||
PrintStrategy(s_strategy);
|
||||
}
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -138,6 +138,7 @@ class OperatorInfo {
|
|||
}
|
||||
StrategyPtr selected_strategy() const { return selected_strategy_; }
|
||||
CostPtr selected_cost() const { return selected_cost_; }
|
||||
void CheckSelectedStrategy(const StrategyPtr &);
|
||||
Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); }
|
||||
void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
|
||||
const std::vector<ValuePtr> &input_value() const { return input_value_; }
|
||||
|
|
|
@ -48,6 +48,16 @@ class Strategy {
|
|||
}
|
||||
void ResetInputs(const std::vector<Dimensions> &input) { inputs_ = input; }
|
||||
|
||||
bool IsEqual(const StrategyPtr &another_stra) {
|
||||
if (another_stra == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if ((stage_ != another_stra->GetInputStage()) || (inputs_ != another_stra->GetInputDim())) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
const int32_t stage_;
|
||||
|
||||
|
|
|
@ -64,5 +64,23 @@ TEST_F(TestStrategy, GetInputDim) {
|
|||
ASSERT_EQ(inputs, inputs_test);
|
||||
}
|
||||
|
||||
// Strategy::IsEqual must report equality only when both the stage and all
// input dimensions match, and must reject a null pointer.
TEST_F(TestStrategy, IsEqual) {
  int32_t stage1 = 0, stage2 = 0, stage3 = 1, stage4 = 0;
  std::vector<int32_t> dimension1 = {8, 1};
  std::vector<int32_t> dimension2 = {1, 8};
  std::vector<std::vector<int32_t>> inputs1 = {dimension1};
  std::vector<std::vector<int32_t>> inputs2 = {dimension1};
  std::vector<std::vector<int32_t>> inputs3 = {dimension2};
  std::vector<std::vector<int32_t>> inputs4 = {dimension1, dimension2};

  StrategyPtr stra1 = std::make_shared<Strategy>(stage1, inputs1);
  StrategyPtr stra2 = std::make_shared<Strategy>(stage2, inputs2);
  StrategyPtr stra3 = std::make_shared<Strategy>(stage3, inputs3);
  StrategyPtr stra4 = std::make_shared<Strategy>(stage4, inputs4);
  // Same dimensions as stra1 but a different stage: isolates the stage comparison.
  StrategyPtr stra5 = std::make_shared<Strategy>(stage3, inputs1);

  // Identical stage and dimensions.
  ASSERT_TRUE(stra1->IsEqual(stra2));
  ASSERT_TRUE(stra1->IsEqual(stra1));
  // Different stage and different dimensions.
  ASSERT_FALSE(stra1->IsEqual(stra3));
  // Same stage, different number of inputs.
  ASSERT_FALSE(stra1->IsEqual(stra4));
  // Different stage only.
  ASSERT_FALSE(stra1->IsEqual(stra5));
  // A null strategy never compares equal.
  ASSERT_FALSE(stra1->IsEqual(nullptr));
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue