diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h
index a60cbc04287..ea0b3a73e1e 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h
@@ -246,14 +246,15 @@ struct SourceEliminationDecision : public Decision {
  */
 struct TriangleEliminationDecision : public Decision {
   TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
-                              StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra)
+                              StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra, CostPtr r_node_cost)
       : eliminated_op_strategy_(std::move(elimi_stra)),
         eliminated_op_cost_(std::move(elimi_op_cost)),
         left_edge_cost_(std::move(l_edge_cost)),
         right_edge_cost_(std::move(r_edge_cost)),
         left_node_strategy_(std::move(left_stra)),
         left_node_cost_(std::move(l_node_cost)),
-        right_node_strategy_(std::move(right_stra)) {
+        right_node_strategy_(std::move(right_stra)),
+        right_node_cost_(std::move(r_node_cost)) {
     type_ = DecisionType::TRIANGLE_ELIMINATION;
   }
 
@@ -264,6 +265,7 @@ struct TriangleEliminationDecision : public Decision {
   StrategyPtr left_node_strategy_;
   CostPtr left_node_cost_;
   StrategyPtr right_node_strategy_;
+  CostPtr right_node_cost_;
   MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
 };
 
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc
index b49ca05f3e1..65b576cef9b 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc
@@ -199,9 +199,16 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
       eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
       left_edge->set_selected_cost(decision->left_edge_cost_);
       right_edge->set_selected_cost(decision->right_edge_cost_);
-      // Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
+      // 'left_node' always recovers its strategy from the decision.
       left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
-      right_node->CheckSelectedStrategy(decision->right_node_strategy_);
+      if (TRIANGLE_STRATEGY_OVERWRITE) {
+        // 'right_node' also recovers its strategy from the decision.
+        MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination.";
+        right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
+      } else {
+        // In this case, the strategy of 'right_node' is not overwritten; it is only checked for consistency.
+        right_node->CheckSelectedStrategy(decision->right_node_strategy_);
+      }
       MS_LOG(INFO) << "Recover triangleElimination succeeded.";
     } else if ((*rit)->isa<StarElimination>()) {
       auto elimination = (*rit)->cast<StarEliminationPtr>();
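Read as a whole, the hunk above changes the recovery rule for an eliminated triangle: 'left_node' always takes the recorded strategy, and with TRIANGLE_STRATEGY_OVERWRITE enabled the right node is now assigned the recorded strategy and cost instead of merely being checked against it. A minimal Python sketch of that rule, using illustrative method names (set_strategy_and_cost, check_strategy) rather than the C++ API:

    # Illustrative sketch only: the method names below are hypothetical
    # stand-ins for SetSelectedStrategyAndCost/CheckSelectedStrategy.
    def recover_triangle(decision, left_node, right_node, overwrite):
        # The left node always takes the strategy/cost recorded at elimination time.
        left_node.set_strategy_and_cost(decision.left_node_strategy, decision.left_node_cost)
        if overwrite:
            # New behavior: force the right node onto the recorded strategy as well.
            right_node.set_strategy_and_cost(decision.right_node_strategy, decision.right_node_cost)
        else:
            # Previous behavior: only verify that the right node's already-selected
            # strategy is consistent with the recorded one.
            right_node.check_strategy(decision.right_node_strategy)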
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
index 5d57718110b..9837edf3b6f 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
@@ -40,6 +40,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
 bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
 bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
 int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
+bool TRIANGLE_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE;
 
 void CostGraph::SetDeviceMemoryAndCostParameter() {
   MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@@ -154,6 +155,14 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
     MS_LOG(INFO) << "multi_subgraphs: false.";
   }
 
+  auto overwrite = CostModelContext::GetInstance()->triangle_strategy_overwrite();
+  TRIANGLE_STRATEGY_OVERWRITE = overwrite;
+  if (TRIANGLE_STRATEGY_OVERWRITE) {
+    MS_LOG(INFO) << "triangle_strategy_overwrite: true.";
+  } else {
+    MS_LOG(INFO) << "triangle_strategy_overwrite: false.";
+  }
+
   // RUN_PHASE
   auto phase = CostModelContext::GetInstance()->run_phase();
   if (phase != 0 && phase != 1) {
@@ -1294,8 +1303,17 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
       elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
       left_node_cost->communication_without_parameter_ +
       right_edge_cost->communication_without_parameter_;
-    auto decision = std::make_shared<TriangleEliminationDecision>(
-      elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra);
+    if (TRIANGLE_STRATEGY_OVERWRITE) {
+      new_computation += right_op_cost->computation_cost_;
+      new_memory += right_op_cost->memory_with_reuse_;
+      new_commu_cost += right_op_cost->communication_cost_;
+      new_commu_forward += right_op_cost->communication_forward_;
+      new_commu_without += right_op_cost->communication_without_parameter_;
+    }
+
+    auto decision =
+      std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost,
+                                                    left_op_stra, left_node_cost, right_op_stra, right_op_cost);
     auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
     new_cost->communication_without_parameter_ = new_commu_without;
     new_cost->communication_with_partial_para_ =
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
index 99d45dfb0a4..46487752ec0 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
@@ -46,6 +46,7 @@ extern bool FULLY_USE_DEVICES;
 extern bool ELEMENTWISE_OP_STRA_FOLLOW;
 extern bool MULTI_SUBGRAPHS;
 extern int32_t RUN_PHASE;
+extern bool TRIANGLE_STRATEGY_OVERWRITE;
 
 class CostGraph {
   // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
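The CreateTriangleEliminationSubCostList change follows from the recovery change: if the right node's strategy will be overwritten, its own cost becomes part of the decision and must be folded into the triangle's combined cost, otherwise the DP comparison would undercount that candidate. A hedged numeric sketch with a reduced, illustrative Cost type (only two of the five aggregated fields are shown):

    from collections import namedtuple

    # Illustrative reduced Cost; the real Cost also carries memory and the
    # forward/without-parameter communication fields aggregated above.
    Cost = namedtuple("Cost", ["computation", "communication"])

    def triangle_cost(elimi_op, left_edge, right_edge, left_node, right_node, overwrite):
        # Base aggregation: eliminated op, both edges, and the left node.
        computation = (elimi_op.computation + left_edge.computation +
                       left_node.computation + right_edge.computation)
        communication = (elimi_op.communication + left_edge.communication +
                         left_node.communication + right_edge.communication)
        if overwrite:
            # The right node will be overwritten during recovery, so its own
            # cost must enter the comparison as well.
            computation += right_node.computation
            communication += right_node.communication
        return Cost(computation, communication)

    # e.g. triangle_cost(Cost(1, 2), Cost(1, 0), Cost(1, 0), Cost(4, 2), Cost(3, 1), True)
    # returns Cost(computation=10, communication=5)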
diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc
index 536895c8deb..e3383ad58d8 100644
--- a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc
+++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc
@@ -64,6 +64,7 @@ void CostModelContext::ResetAlgoParameters() {
   tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
   fully_use_device_ = DEFAULT_FULLY_USE_DEVICES;
   elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
+  triangle_strategy_overwrite_ = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE;
 }
 
 void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) {
@@ -133,6 +134,8 @@ void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
   elementwise_stra_follow_ = elementwise_follow;
 }
 
+void CostModelContext::set_triangle_strategy_overwrite(bool overwrite) { triangle_strategy_overwrite_ = overwrite; }
+
 void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; }
 
 struct CostRegister {
diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.h b/mindspore/ccsrc/frontend/parallel/costmodel_context.h
index b1668e13ef6..e809b53e33f 100644
--- a/mindspore/ccsrc/frontend/parallel/costmodel_context.h
+++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.h
@@ -44,6 +44,8 @@ namespace parallel {
 #define DEFAULT_RUN_PHASE 0
 #define TRAINING_PHASE 0
 #define INFERENCE_PHASE 1
+#define DEFAULT_TRIANGLE_STRATEGY_OVERWRITE true
+
 class CostModelContext {
  public:
   ~CostModelContext() = default;
@@ -133,6 +135,9 @@ class CostModelContext {
   void set_elementwise_stra_follow(bool);
   bool elementwise_stra_follow() const { return elementwise_stra_follow_; }
 
+  void set_triangle_strategy_overwrite(bool);
+  bool triangle_strategy_overwrite() const { return triangle_strategy_overwrite_; }
+
   void set_run_phase(int32_t);
   int32_t run_phase() const { return run_phase_; }
 
@@ -167,6 +172,10 @@ class CostModelContext {
   // MULTI_SUBGRAPHS
   bool is_multi_subgraphs_;
 
+  // Whether to overwrite the right-node strategy when the DP algorithm's
+  // recovery phase encounters a triangle structure.
+  bool triangle_strategy_overwrite_;
+
   int32_t run_phase_;  // 0: 'training', 1: 'inference'
 
   int32_t costmodel_allreduce_fusion_algorithm_;
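Note that the patch only wires the flag through the C++ CostModelContext; no Python-side registration appears in this diff. Assuming the flag is (or will be) exposed through the existing cost-model context helper, toggling it might look like the sketch below, where the triangle_strategy_overwrite keyword is an assumption rather than an API this patch registers:

    # Hypothetical usage: set_cost_model_context exists in MindSpore, but the
    # keyword below is an assumption; this diff adds only the C++ setter/getter.
    from mindspore.parallel._cost_model_context import set_cost_model_context

    # DEFAULT_TRIANGLE_STRATEGY_OVERWRITE is true; disabling it falls back to
    # the old behavior of only checking the right node's strategy consistency.
    set_cost_model_context(triangle_strategy_overwrite=False)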
diff --git a/tests/ut/python/parallel/test_auto_parallel_triangle_overwrite.py b/tests/ut/python/parallel/test_auto_parallel_triangle_overwrite.py
new file mode 100644
index 00000000000..1436e1361eb
--- /dev/null
+++ b/tests/ut/python/parallel/test_auto_parallel_triangle_overwrite.py
@@ -0,0 +1,73 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.common.api import _executor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from mindspore.parallel._utils import _reset_op_id as reset_op_id
+from mindspore import context, Tensor, Parameter
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+grad_all = C.GradOperation(get_all=True)
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x):
+        predict = self.network(x)
+        return self.loss(predict)
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x):
+        return grad_all(self.network)(x)
+
+def test_triangle_strategy_consistency():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.mul1 = P.Mul().shard(((2, 4), (2, 4)))
+            self.mul2 = P.Mul()
+            self.ba1 = P.BiasAdd()
+            self.weight = Parameter(Tensor(np.ones([128, 1000]), dtype=ms.float32), name="weight")
+            self.bias = Parameter(Tensor(np.ones([1000]), dtype=ms.float32), name="bias")
+            self.add = P.TensorAdd().shard(((1, 8), (1, 8)))
+            self.relu = P.ReLU()
+
+        def construct(self, x):
+            out = self.mul1(x, self.weight)
+            mul_out = self.mul2(out, self.weight)
+            ba_out = self.ba1(out, self.bias)
+            ta_out = self.add(mul_out, ba_out)
+            out = self.relu(ta_out)
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 1000]), dtype=ms.float32)
+    net = NetWithLoss(Net())
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
+    reset_op_id()
+
+    _executor.compile(net, x, phase='train')