add a flag to control whether to overwrite the right-node in the triangle elimination of the DP algorithm

This commit is contained in:
Xiaoda Zhang 2020-09-17 16:07:54 +08:00
parent 0e60ed8927
commit 970490a6f0
7 changed files with 119 additions and 6 deletions

View File

@ -246,14 +246,15 @@ struct SourceEliminationDecision : public Decision {
*/
struct TriangleEliminationDecision : public Decision {
TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra)
StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra, CostPtr r_node_cost)
: eliminated_op_strategy_(std::move(elimi_stra)),
eliminated_op_cost_(std::move(elimi_op_cost)),
left_edge_cost_(std::move(l_edge_cost)),
right_edge_cost_(std::move(r_edge_cost)),
left_node_strategy_(std::move(left_stra)),
left_node_cost_(std::move(l_node_cost)),
right_node_strategy_(std::move(right_stra)) {
right_node_strategy_(std::move(right_stra)),
right_node_cost_(std::move(r_node_cost)) {
type_ = DecisionType::TRIANGLE_ELIMINATION;
}
@ -264,6 +265,7 @@ struct TriangleEliminationDecision : public Decision {
StrategyPtr left_node_strategy_;
CostPtr left_node_cost_;
StrategyPtr right_node_strategy_;
CostPtr right_node_cost_;
MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
};

View File

@ -199,9 +199,16 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
left_edge->set_selected_cost(decision->left_edge_cost_);
right_edge->set_selected_cost(decision->right_edge_cost_);
// Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
// 'left_node' recovers the strategy.
left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
right_node->CheckSelectedStrategy(decision->right_node_strategy_);
if (TRIANGLE_STRATEGY_OVERWRITE) {
// 'right_node' recovers the strategy.
MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination.";
right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
} else {
// In this case, the strategy of 'right_node' is not overwritten; instead, its selected strategy is checked for consistency.
right_node->CheckSelectedStrategy(decision->right_node_strategy_);
}
MS_LOG(INFO) << "Recover triangleElimination succeeded.";
} else if ((*rit)->isa<StarElimination>()) {
auto elimination = (*rit)->cast<StarEliminationPtr>();

View File

@ -40,6 +40,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
bool TRIANGLE_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE;
void CostGraph::SetDeviceMemoryAndCostParameter() {
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@ -154,6 +155,14 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
MS_LOG(INFO) << "multi_subgraphs: false.";
}
auto overwrite = CostModelContext::GetInstance()->triangle_strategy_overwrite();
TRIANGLE_STRATEGY_OVERWRITE = overwrite;
if (TRIANGLE_STRATEGY_OVERWRITE) {
MS_LOG(INFO) << "triangle_strategy_overwrite: true.";
} else {
MS_LOG(INFO) << "triangle_strategy_overwrite: false.";
}
// RUN_PHASE
auto phase = CostModelContext::GetInstance()->run_phase();
if (phase != 0 && phase != 1) {
@ -1294,8 +1303,17 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
auto decision = std::make_shared<TriangleEliminationDecision>(
elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra);
if (TRIANGLE_STRATEGY_OVERWRITE) {
new_computation += right_op_cost->computation_cost_;
new_memory += right_op_cost->memory_with_reuse_;
new_commu_cost += right_op_cost->communication_cost_;
new_commu_forward += right_op_cost->communication_forward_;
new_commu_without += right_op_cost->communication_without_parameter_;
}
auto decision =
std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost,
left_op_stra, left_node_cost, right_op_stra, right_op_cost);
auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
new_cost->communication_without_parameter_ = new_commu_without;
new_cost->communication_with_partial_para_ =

View File

@ -46,6 +46,7 @@ extern bool FULLY_USE_DEVICES;
extern bool ELEMENTWISE_OP_STRA_FOLLOW;
extern bool MULTI_SUBGRAPHS;
extern int32_t RUN_PHASE;
extern bool TRIANGLE_STRATEGY_OVERWRITE;
class CostGraph {
// 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have

View File

@ -64,6 +64,7 @@ void CostModelContext::ResetAlgoParameters() {
tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
fully_use_device_ = DEFAULT_FULLY_USE_DEVICES;
elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
triangle_strategy_overwrite_ = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE;
}
void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) {
@ -133,6 +134,8 @@ void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
elementwise_stra_follow_ = elementwise_follow;
}
void CostModelContext::set_triangle_strategy_overwrite(bool overwrite) { triangle_strategy_overwrite_ = overwrite; }
void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; }
struct CostRegister {

View File

@ -44,6 +44,8 @@ namespace parallel {
#define DEFAULT_RUN_PHASE 0
#define TRAINING_PHASE 0
#define INFERENCE_PHASE 1
#define DEFAULT_TRIANGLE_STRATEGY_OVERWRITE true;
class CostModelContext {
public:
~CostModelContext() = default;
@ -133,6 +135,9 @@ class CostModelContext {
void set_elementwise_stra_follow(bool);
bool elementwise_stra_follow() const { return elementwise_stra_follow_; }
void set_triangle_strategy_overwrite(bool);
bool triangle_strategy_overwrite() const { return triangle_strategy_overwrite_; }
void set_run_phase(int32_t);
int32_t run_phase() const { return run_phase_; }
@ -167,6 +172,10 @@ class CostModelContext {
// MULTI_SUBGRAPHS
bool is_multi_subgraphs_;
// In the recovery phase of DP algorithm, when encountering triangle structure,
// whether to overwrite the right-node's strategy
bool triangle_strategy_overwrite_;
int32_t run_phase_; // 0: 'training', 1: 'inference'
int32_t costmodel_allreduce_fusion_algorithm_;

View File

@ -0,0 +1,73 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id as reset_op_id
from mindspore import context, Tensor, Parameter
from tests.ut.python.ops.test_math_ops import VirtualLoss
grad_all = C.GradOperation(get_all=True)
class NetWithLoss(nn.Cell):
    """Wraps a network so its forward output is fed through a virtual loss node."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        # Forward through the wrapped network, then apply the virtual loss.
        return self.loss(self.network(x))
class GradWarp(nn.Cell):
    """Wraps a network so calling it yields the gradients w.r.t. all inputs."""

    def __init__(self, network):
        super(GradWarp, self).__init__()
        self.network = network

    def construct(self, x):
        # Build the all-inputs gradient function, then evaluate it at x.
        grad_fn = grad_all(self.network)
        return grad_fn(x)
def test_triangle_strategy_consistency():
    """Compile a net whose graph contains a triangle structure under auto-parallel.

    'out' (from mul1) feeds both mul2 and ba1, and their results are joined by
    'add' — that fan-out/fan-in forms the triangle that the DP algorithm's
    triangle elimination handles. Compilation succeeding is the assertion:
    strategy recovery must leave the right-node's strategy consistent.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # mul1 is pinned to a (2, 4) shard on both inputs; mul2/ba1 are left
            # unannotated so the cost model must search their strategies.
            self.mul1 = P.Mul().shard(((2, 4), (2, 4)))
            self.mul2 = P.Mul()
            self.ba1 = P.BiasAdd()
            self.weight = Parameter(Tensor(np.ones([128, 1000]), dtype=ms.float32), name="weight")
            self.bias = Parameter(Tensor(np.ones([1000]), dtype=ms.float32), name="bias")
            # add is pinned to a (1, 8) shard, forcing a strategy that the two
            # triangle branches must reconcile with.
            self.add = P.TensorAdd().shard(((1, 8), (1, 8)))
            self.relu = P.ReLU()

        def construct(self, x):
            # 'out' is consumed by two branches (mul2 and ba1) -> triangle.
            out = self.mul1(x, self.weight)
            mul_out = self.mul2(out, self.weight)
            ba_out = self.ba1(out, self.bias)
            ta_out = self.add(mul_out, ba_out)
            out = self.relu(ta_out)
            return out

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([128, 1000]), dtype=ms.float32)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    # Reset operator ids so generated strategies are deterministic across runs.
    reset_op_id()
    _executor.compile(net, x, phase='train')