changing the successor edges order in GetAliveSuccEdges() so that Triangle and Star Elimination can be merged into a particular node; adding some check information

This commit is contained in:
Xiaoda Zhang 2020-05-18 15:31:05 +08:00
parent d9c74e0acd
commit 9f4b8a3cd1
8 changed files with 75 additions and 12 deletions

View File

@ -211,13 +211,14 @@ struct ContractEliminationDecision : public Decision {
*/
struct TriangleEliminationDecision : public Decision {
TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
StrategyPtr left_stra, CostPtr l_node_cost)
StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra)
: eliminated_op_strategy_(std::move(elimi_stra)),
eliminated_op_cost_(std::move(elimi_op_cost)),
left_edge_cost_(std::move(l_edge_cost)),
right_edge_cost_(std::move(r_edge_cost)),
left_node_strategy_(std::move(left_stra)),
left_node_cost_(std::move(l_node_cost)) {
left_node_cost_(std::move(l_node_cost)),
right_node_strategy_(std::move(right_stra)) {
type_ = DecisionType::TRIANGLE_ELIMINATION;
}
@ -227,6 +228,7 @@ struct TriangleEliminationDecision : public Decision {
CostPtr right_edge_cost_;
StrategyPtr left_node_strategy_;
CostPtr left_node_cost_;
StrategyPtr right_node_strategy_;
MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
};

View File

@ -85,7 +85,9 @@ Status GetStrategy(const CostGraphPtr &graph) {
right_edge = tmp;
}
auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
auto elimi = std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge);
auto right_node = l_r_edge->next_operator();
auto elimi =
std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
eliminations.emplace_back(std::move(elimi));
}
auto star_center = graph->CheckStarElimination();
@ -181,6 +183,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
auto left_edge = elimination->left_edge_;
auto eliminated_node = elimination->eliminated_node_;
auto right_edge = elimination->right_edge_;
auto right_node = elimination->right_node_;
auto decision = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>();
eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
@ -188,6 +191,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
right_edge->set_selected_cost(decision->right_edge_cost_);
// Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
right_node->CheckSelectedStrategy(decision->right_node_strategy_);
MS_LOG(INFO) << "Recover triangleElimination succeeded.";
} else if ((*rit)->isa<StarElimination>()) {
auto elimination = (*rit)->cast<StarEliminationPtr>();
@ -206,6 +210,9 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]);
// Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy.
succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
for (size_t k = 1; k < succ_nodes.size(); ++k) {
succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]);
}
MS_LOG(INFO) << "Recover starElimination succeeded.";
} else {
MS_LOG(ERROR) << "Unknown Elimination type.";

View File

@ -102,17 +102,20 @@ struct ContractElimination : public Elimination {
// Triangle Elimination
struct TriangleElimination : public Elimination {
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge)
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
OperatorInfoPtr r_node)
: Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
eliminated_node_(std::move(elim_node)),
left_edge_(std::move(l_edge)),
left_node_(std::move(l_node)),
right_edge_(std::move(r_edge)) {}
right_edge_(std::move(r_edge)),
right_node_(std::move(r_node)) {}
OperatorInfoPtr eliminated_node_;
EdgePtr left_edge_;
OperatorInfoPtr left_node_;
EdgePtr right_edge_;
OperatorInfoPtr right_node_;
MS_DECLARE_PARENT(TriangleElimination, Elimination);
};

View File

@ -1111,8 +1111,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost,
right_edge_cost, left_op_stra, left_node_cost);
auto decision = std::make_shared<TriangleEliminationDecision>(
elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra);
auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
new_cost->communication_without_parameter_ = new_commu_without;
new_cost->communication_with_partial_para_ =

View File

@ -546,10 +546,14 @@ std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAliveSuccEdges() {
for (auto &edge : succ_edges_) {
if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) {
ret.push_back(edge);
} else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) {
// CAST is ordered in front of L2NORMALIZE
ret.push_back(edge);
}
}
for (auto &edge : succ_edges_) {
if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos)) {
if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) &&
(edge->next_operator()->name().find(CAST) == std::string::npos)) {
ret.push_back(edge);
}
}
@ -1279,10 +1283,18 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra
CheckGlobalDeviceManager();
auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size();
if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) {
cost->computation_cost_ -= 1.0;
cost->communication_cost_ -= 1.0;
cost->communication_with_partial_para_ -= 1.0;
cost->communication_without_parameter_ -= 1.0;
if (cost->computation_cost_ > 1.0) {
cost->computation_cost_ -= 1.0;
}
if (cost->communication_cost_ > 1.0) {
cost->communication_cost_ -= 1.0;
}
if (cost->communication_with_partial_para_ > 1.0) {
cost->communication_with_partial_para_ -= 1.0;
}
if (cost->communication_without_parameter_ > 1.0) {
cost->communication_without_parameter_ -= 1.0;
}
}
}
}
@ -1290,5 +1302,15 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra
// Returns the result of operator_cost()->GetForwardComputationCost() evaluated on this
// operator's inputs_tensor_info_ and outputs_tensor_info_, with 0 as the final argument.
// NOTE(review): the function name says "memory cost" while the call made is
// GetForwardComputationCost -- confirm this is intended.
// NOTE(review): the trailing 0 is presumably a stage id -- verify against the cost-model API.
double OperatorInfo::GetForwardMemoryCostFromCNode() {
return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
}
// Diagnostic check: compares `s_strategy` with the strategy already stored in
// selected_strategy_ and, when they differ, logs both at INFO level.
// `s_strategy` is validated with MS_EXCEPTION_IF_NULL before use.
// Nothing is assigned here; the function only reads and reports.
void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) {
  MS_EXCEPTION_IF_NULL(s_strategy);
  // Matching strategies need no report.
  if (s_strategy->IsEqual(selected_strategy_)) {
    return;
  }
  MS_LOG(INFO) << name() << "'s strategy may cause suboptimal, the determined strategy:";
  PrintStrategy(selected_strategy_);
  MS_LOG(INFO) << "The minimal strategy:";
  PrintStrategy(s_strategy);
}
} // namespace parallel
} // namespace mindspore

View File

@ -138,6 +138,7 @@ class OperatorInfo {
}
StrategyPtr selected_strategy() const { return selected_strategy_; }
CostPtr selected_cost() const { return selected_cost_; }
void CheckSelectedStrategy(const StrategyPtr &);
Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); }
void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
const std::vector<ValuePtr> &input_value() const { return input_value_; }

View File

@ -48,6 +48,16 @@ class Strategy {
}
void ResetInputs(const std::vector<Dimensions> &input) { inputs_ = input; }
bool IsEqual(const StrategyPtr &another_stra) {
if (another_stra == nullptr) {
return false;
}
if ((stage_ != another_stra->GetInputStage()) || (inputs_ != another_stra->GetInputDim())) {
return false;
}
return true;
}
private:
const int32_t stage_;

View File

@ -64,5 +64,23 @@ TEST_F(TestStrategy, GetInputDim) {
ASSERT_EQ(inputs, inputs_test);
}
// Verifies Strategy::IsEqual: equal stage and inputs compare equal, while a different
// stage/dimension or a different number of inputs compares unequal.
TEST_F(TestStrategy, IsEqual) {
  const std::vector<int32_t> row_split = {8, 1};
  const std::vector<int32_t> col_split = {1, 8};

  // Baseline strategy: stage 0 with a single {8, 1} input.
  StrategyPtr base = std::make_shared<Strategy>(0, std::vector<std::vector<int32_t>>{row_split});
  // Same stage and same inputs as the baseline.
  StrategyPtr same = std::make_shared<Strategy>(0, std::vector<std::vector<int32_t>>{row_split});
  // Different stage and different dimension.
  StrategyPtr other_stage = std::make_shared<Strategy>(1, std::vector<std::vector<int32_t>>{col_split});
  // Same stage, but two inputs instead of one.
  StrategyPtr more_inputs = std::make_shared<Strategy>(0, std::vector<std::vector<int32_t>>{row_split, col_split});

  ASSERT_TRUE(base->IsEqual(same));
  ASSERT_FALSE(base->IsEqual(other_stage));
  ASSERT_FALSE(base->IsEqual(more_inputs));
}
} // namespace parallel
} // namespace mindspore