forked from mindspore-Ecosystem/mindspore

add a flag to control whether to overwrite the right-node strategy in the triangle elimination of the DP algorithm

This commit is contained in:
parent 0e60ed8927
commit 970490a6f0
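The change threads a single boolean, TRIANGLE_STRATEGY_OVERWRITE, from CostModelContext into the cost graph: during triangle elimination it decides whether the right node's own cost is folded into the merged cost, and during recovery it decides whether the right node's strategy is overwritten from the recorded decision or merely checked for consistency. A minimal caller-side sketch, assuming only the setter/getter and global added in this diff (the call site itself is hypothetical):

    // Hedged sketch: toggle the new flag before the DP algorithm runs.
    // Only set_triangle_strategy_overwrite()/triangle_strategy_overwrite() and
    // SetDeviceMemoryAndCostParameter() below are taken from this diff.
    auto cm_context = CostModelContext::GetInstance();
    cm_context->set_triangle_strategy_overwrite(false);  // keep right-node strategies, only check consistency
    // SetDeviceMemoryAndCostParameter() later copies the context value into the
    // global TRIANGLE_STRATEGY_OVERWRITE read by elimination and recovery.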
@@ -246,14 +246,15 @@ struct SourceEliminationDecision : public Decision {
  */
 struct TriangleEliminationDecision : public Decision {
   TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
-                              StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra)
+                              StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra, CostPtr r_node_cost)
       : eliminated_op_strategy_(std::move(elimi_stra)),
         eliminated_op_cost_(std::move(elimi_op_cost)),
         left_edge_cost_(std::move(l_edge_cost)),
         right_edge_cost_(std::move(r_edge_cost)),
         left_node_strategy_(std::move(left_stra)),
         left_node_cost_(std::move(l_node_cost)),
-        right_node_strategy_(std::move(right_stra)) {
+        right_node_strategy_(std::move(right_stra)),
+        right_node_cost_(std::move(r_node_cost)) {
     type_ = DecisionType::TRIANGLE_ELIMINATION;
   }

@@ -264,6 +265,7 @@ struct TriangleEliminationDecision : public Decision {
   StrategyPtr left_node_strategy_;
   CostPtr left_node_cost_;
   StrategyPtr right_node_strategy_;
+  CostPtr right_node_cost_;
   MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
 };

@@ -199,9 +199,16 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
       eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
       left_edge->set_selected_cost(decision->left_edge_cost_);
       right_edge->set_selected_cost(decision->right_edge_cost_);
-      // Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
+      // 'left_node' recovers the strategy.
       left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
-      right_node->CheckSelectedStrategy(decision->right_node_strategy_);
+      if (TRIANGLE_STRATEGY_OVERWRITE) {
+        // 'right_node' recovers the strategy.
+        MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination.";
+        right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
+      } else {
+        // In this case, the strategy of 'right_node' is not overwritten; only its consistency is checked.
+        right_node->CheckSelectedStrategy(decision->right_node_strategy_);
+      }
       MS_LOG(INFO) << "Recover triangleElimination succeeded.";
     } else if ((*rit)->isa<StarElimination>()) {
       auto elimination = (*rit)->cast<StarEliminationPtr>();

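For intuition, the two branches differ in who wins when the right node already carries a strategy chosen by an earlier recovery step: overwriting installs the strategy and cost recorded when the triangle was eliminated, while checking keeps the existing choice and only verifies it against the record. A toy sketch of that contrast (all types here are hypothetical stand-ins; only the two method names mirror the diff):

    #include <iostream>
    #include <string>

    // Hypothetical stand-in; the real StrategyPtr/operator types live in MindSpore.
    struct ToyNode {
      std::string name;
      std::string selected_strategy;

      // Mirrors SetSelectedStrategyAndCost: unconditionally installs the recorded choice.
      void SetSelectedStrategy(const std::string &recorded) { selected_strategy = recorded; }

      // Mirrors CheckSelectedStrategy: keeps the current choice, reports a mismatch.
      void CheckSelectedStrategy(const std::string &recorded) const {
        if (selected_strategy != recorded) {
          std::cout << name << ": kept " << selected_strategy
                    << ", triangle recorded " << recorded << '\n';
        }
      }
    };

    int main() {
      ToyNode right{"right_node", "(1,8)"};   // strategy decided by some other elimination
      const std::string recorded = "(2,4)";   // strategy recorded at triangle elimination
      right.CheckSelectedStrategy(recorded);  // overwrite=false: report mismatch, keep "(1,8)"
      right.SetSelectedStrategy(recorded);    // overwrite=true: becomes "(2,4)"
      return 0;
    }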
@@ -40,6 +40,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
 bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
 bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
 int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
+bool TRIANGLE_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE;

 void CostGraph::SetDeviceMemoryAndCostParameter() {
   MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());

@@ -154,6 +155,14 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
     MS_LOG(INFO) << "multi_subgraphs: false.";
   }

+  auto overwrite = CostModelContext::GetInstance()->triangle_strategy_overwrite();
+  TRIANGLE_STRATEGY_OVERWRITE = overwrite;
+  if (TRIANGLE_STRATEGY_OVERWRITE) {
+    MS_LOG(INFO) << "triangle_strategy_overwrite: true.";
+  } else {
+    MS_LOG(INFO) << "triangle_strategy_overwrite: false.";
+  }
+
   // RUN_PHASE
   auto phase = CostModelContext::GetInstance()->run_phase();
   if (phase != 0 && phase != 1) {

@@ -1294,8 +1303,17 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
     elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
     left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;

-  auto decision = std::make_shared<TriangleEliminationDecision>(
-    elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra);
+  if (TRIANGLE_STRATEGY_OVERWRITE) {
+    new_computation += right_op_cost->computation_cost_;
+    new_memory += right_op_cost->memory_with_reuse_;
+    new_commu_cost += right_op_cost->communication_cost_;
+    new_commu_forward += right_op_cost->communication_forward_;
+    new_commu_without += right_op_cost->communication_without_parameter_;
+  }
+
+  auto decision =
+    std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost,
+                                                  left_op_stra, left_node_cost, right_op_stra, right_op_cost);
   auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
   new_cost->communication_without_parameter_ = new_commu_without;
   new_cost->communication_with_partial_para_ =

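Note the bookkeeping that pairs with the recovery change: if the right node's strategy will be overwritten from this decision, the right node's own operator cost now belongs to the merged triangle cost, so it is accumulated here; otherwise the node keeps an independently chosen strategy and its cost is accounted for elsewhere. A minimal sketch of that rule, assuming a toy cost struct (the field names follow the diff; everything else is hypothetical):

    // Toy stand-in for the Cost fields touched above.
    struct ToyCost {
      double computation_cost_ = 0;
      double communication_cost_ = 0;
      double communication_without_parameter_ = 0;
    };

    // Fold the right node's own cost into the merged triangle cost only when
    // its strategy will be overwritten during recovery.
    ToyCost MergeTriangleCost(ToyCost merged, const ToyCost &right_op_cost, bool overwrite_right) {
      if (overwrite_right) {
        merged.computation_cost_ += right_op_cost.computation_cost_;
        merged.communication_cost_ += right_op_cost.communication_cost_;
        merged.communication_without_parameter_ += right_op_cost.communication_without_parameter_;
      }
      return merged;
    }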
@@ -46,6 +46,7 @@ extern bool FULLY_USE_DEVICES;
 extern bool ELEMENTWISE_OP_STRA_FOLLOW;
 extern bool MULTI_SUBGRAPHS;
 extern int32_t RUN_PHASE;
+extern bool TRIANGLE_STRATEGY_OVERWRITE;

 class CostGraph {
   // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have

@@ -64,6 +64,7 @@ void CostModelContext::ResetAlgoParameters() {
   tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
   fully_use_device_ = DEFAULT_FULLY_USE_DEVICES;
   elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
+  triangle_strategy_overwrite_ = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE;
 }

 void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) {

@@ -133,6 +134,8 @@ void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
   elementwise_stra_follow_ = elementwise_follow;
 }

+void CostModelContext::set_triangle_strategy_overwrite(bool overwrite) { triangle_strategy_overwrite_ = overwrite; }
+
 void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; }

 struct CostRegister {

@@ -44,6 +44,8 @@ namespace parallel {
 #define DEFAULT_RUN_PHASE 0
 #define TRAINING_PHASE 0
 #define INFERENCE_PHASE 1
+#define DEFAULT_TRIANGLE_STRATEGY_OVERWRITE true

 class CostModelContext {
  public:
  ~CostModelContext() = default;

@@ -133,6 +135,9 @@ class CostModelContext {
   void set_elementwise_stra_follow(bool);
   bool elementwise_stra_follow() const { return elementwise_stra_follow_; }

+  void set_triangle_strategy_overwrite(bool);
+  bool triangle_strategy_overwrite() const { return triangle_strategy_overwrite_; }
+
   void set_run_phase(int32_t);
   int32_t run_phase() const { return run_phase_; }

@@ -167,6 +172,10 @@ class CostModelContext {
   // MULTI_SUBGRAPHS
   bool is_multi_subgraphs_;

+  // In the recovery phase of the DP algorithm, when a triangle structure is encountered,
+  // this decides whether to overwrite the right-node strategy.
+  bool triangle_strategy_overwrite_;
+
   int32_t run_phase_;  // 0: 'training', 1: 'inference'

   int32_t costmodel_allreduce_fusion_algorithm_;

@@ -0,0 +1,73 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.common.api import _executor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from mindspore.parallel._utils import _reset_op_id as reset_op_id
+from mindspore import context, Tensor, Parameter
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+grad_all = C.GradOperation(get_all=True)
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x):
+        predict = self.network(x)
+        return self.loss(predict)
+
+class GradWarp(nn.Cell):
+    def __init__(self, network):
+        super(GradWarp, self).__init__()
+        self.network = network
+
+    def construct(self, x):
+        return grad_all(self.network)(x)
+
+def test_triangle_strategy_consistency():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.mul1 = P.Mul().shard(((2, 4), (2, 4)))
+            self.mul2 = P.Mul()
+            self.ba1 = P.BiasAdd()
+            self.weight = Parameter(Tensor(np.ones([128, 1000]), dtype=ms.float32), name="weight")
+            self.bias = Parameter(Tensor(np.ones([1000]), dtype=ms.float32), name="bias")
+            self.add = P.TensorAdd().shard(((1, 8), (1, 8)))
+            self.relu = P.ReLU()
+
+        def construct(self, x):
+            out = self.mul1(x, self.weight)
+            mul_out = self.mul2(out, self.weight)
+            ba_out = self.ba1(out, self.bias)
+            ta_out = self.add(mul_out, ba_out)
+            out = self.relu(ta_out)
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 1000]), dtype=ms.float32)
+    net = NetWithLoss(Net())
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
+    reset_op_id()
+
+    _executor.compile(net, x, phase='train')