slice_activation_in_recompute

slice recompute activation
This commit is contained in:
yao_yf 2021-11-06 16:12:23 +08:00
parent 5c573e6d7d
commit 188d39da83
11 changed files with 501 additions and 26 deletions

View File

@ -21,6 +21,7 @@
#include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h" #include "frontend/optimizer/anf_visitor.h"
#include "frontend/parallel/context.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
namespace mindspore { namespace mindspore {
@ -47,6 +48,13 @@ class SetCellOutputNoRecompute : public AnfVisitor {
for (const auto &real_output : real_outputs) { for (const auto &real_output : real_outputs) {
// Set the attr of cnode in case of shared primitives. // Set the attr of cnode in case of shared primitives.
real_output->AddAttr(kAttrRecompute, MakeValue(false)); real_output->AddAttr(kAttrRecompute, MakeValue(false));
if (parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::SEMI_AUTO_PARALLEL ||
parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::AUTO_PARALLEL) {
auto prim = GetCNodePrimitive(real_output);
if (prim->HasAttr(kAttrSliceActivation) && GetValue<bool>(prim->GetAttr(kAttrSliceActivation))) {
real_output->AddAttr(kAttrSliceActivation, MakeValue(true));
}
}
} }
} }
fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE); fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE);

View File

@ -0,0 +1,309 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/slice_activation_in_recompute.h"
#include <memory>
#include <queue>
#include <utility>
#include <list>
#include <vector>
#include <string>
#include <algorithm>
#include "mindspore/core/base/core_ops.h"
#include "utils/utils.h"
#include "frontend/parallel/tensor_layout/construct_operator.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace opt {
namespace {
constexpr auto kGradientsFlag = "Gradients";
const int64_t max_loop_size = 100;
// Tells whether `node` is a CNode whose scoped full name starts with kGradientsFlag ("Gradients").
// Throws (via MS_EXCEPTION_IF_NULL) when `node` is null; any non-CNode yields false.
bool IsBpropNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<CNode>()) {
    const auto &scoped_name = node->fullname_with_scope();
    // Anchoring the reverse search at position 0 only accepts a match at the very beginning of the name.
    return scoped_name.rfind(kGradientsFlag, 0) == 0;
  }
  return false;
}
// Creates a StridedSlice CNode that reads `node` with the given begin/end/strides.
// The new CNode is created in the func graph that owns `node`; its abstract is left for the caller to set.
CNodePtr CreateStridedSliceCNode(const parallel::Shape &begin, const parallel::Shape &end,
                                 const parallel::Shape &strides, const AnfNodePtr &node) {
  auto strided_slice_op = parallel::CreateStridedSliceOp(0, begin, end, strides);
  auto cnode_inputs = parallel::CreateInput(strided_slice_op, node, parallel::STRIDEDSLICE);
  return node->func_graph()->NewCNode(cnode_inputs);
}
// Creates an AllGather CNode over communication group `group` that reads `node`.
// The operator instance is named "recompute_slice_allgather" and the CNode is created in the func graph
// that owns `node`; its abstract is left for the caller to set.
// Throws (via MS_EXCEPTION_IF_NULL) when `node` or its func graph is null.
CNodePtr CreateAllGatherCNode(const AnfNodePtr &node, const std::string &group) {
  MS_EXCEPTION_IF_NULL(node);
  auto op = parallel::CreateAllGatherOp(group);
  auto allgather_input = parallel::CreateInput(op, node, "recompute_slice_allgather");
  auto func_graph = node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  return func_graph->NewCNode(allgather_input);
}
// Returns the first communication group that OperatorInfo::CreateGroupByTensorMap builds from the
// origin tensor map of the node's single output.
// NOTE(review): presumably this group is the set of ranks that hold identical copies of the output —
// confirm against OperatorInfo::CreateGroupByTensorMap.
// Raises an exception when the node carries no OperatorInfo, when it does not have exactly one output
// tensor info, or when no group was produced (previously the latter read groups[0] of an empty vector).
parallel::Group InferRepeatedRankList(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  OperatorInfoPtr operator_info = cnode->user_data<parallel::OperatorInfo>();
  MS_EXCEPTION_IF_NULL(operator_info);
  std::vector<parallel::TensorInfo> output_info = operator_info->outputs_tensor_info();
  if (output_info.size() != 1) {
    MS_LOG(EXCEPTION) << "The output_info size is wrong, node is" << cnode->DebugString();
  }
  auto tensor_layout = output_info[0].tensor_layout();
  auto tensor_map = tensor_layout.origin_tensor_map();
  std::vector<parallel::Group> groups;
  operator_info->CreateGroupByTensorMap(tensor_map.array(), &groups);
  if (groups.empty()) {
    // Fail loudly instead of indexing an empty vector.
    MS_LOG(EXCEPTION) << "No communication group was created from the output tensor map of node: "
                      << cnode->DebugString();
  }
  return groups[0];
}
bool IsDuplicateNode(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return false;
}
if (node->cast<CNodePtr>()->HasAttr(kAttrDuplicated)) {
return true;
}
if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
auto manager = node->func_graph()->manager();
auto node_users = manager->node_users()[node];
return std::any_of(node_users.begin(), node_users.end(),
[](auto node_user) { return IsDuplicateNode(node_user.first); });
}
return false;
}
// Returns the first real input of `cnode` (inputs()[1] onwards, i.e. skipping slot 0) that
// IsBpropNode accepts, cast to CNodePtr; returns nullptr when no input qualifies.
CNodePtr BpropInput(const CNodePtr cnode) {
  const auto &all_inputs = cnode->inputs();
  if (all_inputs.size() <= 1) {
    return nullptr;
  }
  const auto first_bprop = std::find_if(all_inputs.begin() + 1, all_inputs.end(),
                                        [](const AnfNodePtr &input) { return IsBpropNode(input); });
  if (first_bprop == all_inputs.end()) {
    return nullptr;
  }
  return (*first_bprop)->cast<CNodePtr>();
}
std::vector<CNodePtr> NextBpropNodes(const CNodePtr cnode) {
std::vector<CNodePtr> next_brop_nodes;
auto manager = cnode->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
auto cnode_users = manager->node_users()[cnode];
std::queue<CNodePtr> cnode_queue;
for (auto &node : cnode_users) {
if (!node.first->isa<CNode>()) {
continue;
}
cnode_queue.push(node.first->cast<CNodePtr>());
}
int64_t loop_size = 0;
while (!cnode_queue.empty() && loop_size < max_loop_size) {
auto cur_cnode = cnode_queue.front();
cnode_queue.pop();
if (IsBpropNode(cur_cnode)) {
next_brop_nodes.push_back(cur_cnode);
continue;
}
auto cur_cnode_users = manager->node_users()[cur_cnode];
for (auto &node : cur_cnode_users) {
if (!node.first->isa<CNode>()) {
continue;
}
cnode_queue.push(node.first->cast<CNodePtr>());
}
loop_size++;
}
if (loop_size < max_loop_size) {
return next_brop_nodes;
}
return {};
}
// Splits the users of `node` into two buckets: users accepted by IsDuplicateNode are appended to
// `duplicate_users`, all remaining users are appended to `forward_users`. Each entry is the
// (user node, input index) pair taken from the manager's node_users table.
void GroupingNextNodes(const CNodePtr &node, std::vector<std::pair<std::shared_ptr<AnfNode>, int>> *duplicate_users,
                       std::vector<std::pair<std::shared_ptr<AnfNode>, int>> *forward_users) {
  auto graph_manager = node->func_graph()->manager();
  // Classify from a snapshot of the user set.
  auto users_snapshot = graph_manager->node_users()[node];
  for (auto user : users_snapshot) {
    auto *bucket = IsDuplicateNode(user.first) ? duplicate_users : forward_users;
    bucket->push_back(user);
  }
}
void InsertSliceAllGatherNode(const std::vector<std::pair<std::shared_ptr<AnfNode>, int>> &node_users,
const std::pair<std::shared_ptr<AnfNode>, int> &forward_node_user,
const std::shared_ptr<CNode> &node, std::vector<CNodePtr> *slice_allgathers,
int64_t recompute_order_id) {
auto manager = node->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
auto output_shape = node->abstract()->BuildShape();
std::vector<int64_t> out_shape_element = output_shape->cast<abstract::ShapePtr>()->shape();
if (out_shape_element.empty()) {
return;
}
int64_t global_rank_id = parallel::g_device_manager->global_rank();
int64_t stage_num = parallel::g_device_manager->stage_num();
int64_t device_num = parallel::g_device_manager->DeviceNum();
int64_t stage_device_num = device_num / stage_num;
int64_t local_rank_id = global_rank_id % stage_device_num;
auto group = InferRepeatedRankList(node);
if (out_shape_element[0] % group.GetDevNum() != 0) {
MS_LOG(WARNING) << "The output_shape first dim:" << out_shape_element[0]
<< " cannot be divisible by the repeated size: " << group.GetDevNum()
<< "The slice would not activate to this node: " << node->DebugString();
return;
}
int64_t group_deivce_num = group.GetDevNum();
std::vector<int64_t> slice_begin(out_shape_element.size(), 0);
slice_begin[0] = (local_rank_id % group_deivce_num) * (out_shape_element[0] / group_deivce_num);
std::vector<int64_t> slice_end = out_shape_element;
slice_end[0] = (local_rank_id % group_deivce_num + 1) * (out_shape_element[0] / group_deivce_num);
std::vector<int64_t> slice_strides(out_shape_element.size(), 1);
CNodePtr slice_cnode = CreateStridedSliceCNode(slice_begin, slice_end, slice_strides, node);
slice_cnode->set_abstract(node->abstract()->Clone());
std::vector<int64_t> slice_shape = out_shape_element;
slice_shape[0] = out_shape_element[0] / group_deivce_num;
std::shared_ptr<abstract::BaseShape> slice_base_shape = std::make_shared<abstract::Shape>(slice_shape);
slice_cnode->abstract()->set_shape(slice_base_shape);
for (auto &node_user : node_users) {
manager->SetEdge(node_user.first, node_user.second, slice_cnode);
}
CNodePtr allgather_cnode = CreateAllGatherCNode(slice_cnode, group.name());
allgather_cnode->set_abstract(node->abstract()->Clone());
allgather_cnode->AddAttr("recompute_order", MakeValue(recompute_order_id));
if (node->HasPrimalAttr(parallel::MICRO)) {
allgather_cnode->AddPrimalAttr(parallel::MICRO, node->GetPrimalAttr(parallel::MICRO));
}
manager->Replace(slice_cnode, allgather_cnode);
slice_allgathers->push_back(allgather_cnode);
std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), forward_node_user.first, slice_cnode};
auto depend_node = node->func_graph()->NewCNode(depend_inputs);
depend_node->set_abstract(forward_node_user.first->abstract()->Clone());
depend_node->AddAttr("slice_forward_depend", MakeValue(true));
MS_EXCEPTION_IF_NULL(depend_node);
manager->Replace(forward_node_user.first, depend_node);
}
void InsertAllGatherDepend(const FuncGraphPtr &graph, const std::vector<CNodePtr> &slice_allgathers) {
auto manager = graph->manager();
auto last_allgather = slice_allgathers.back();
auto next_cnodes = NextBpropNodes(last_allgather);
CNodePtr next_cnode = nullptr;
if (!next_cnodes.empty()) {
MS_LOG(INFO) << "The next_cnodes is not empty.";
std::list<CNodePtr> orders = graph->GetOrderedCnodes();
for (auto &cnode : orders) {
if (std::find(next_cnodes.begin(), next_cnodes.end(), cnode) != next_cnodes.end()) {
next_cnode = cnode;
break;
}
}
}
for (size_t i = slice_allgathers.size() - 1; i > 0; --i) {
std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), slice_allgathers[i - 1]->input(1),
slice_allgathers[i]};
auto depend_node = graph->NewCNode(depend_inputs);
MS_EXCEPTION_IF_NULL(depend_node);
depend_node->set_abstract(slice_allgathers[i - 1]->input(1)->abstract()->Clone());
depend_node->AddAttr("slice_allgather_depend", MakeValue(i));
manager->SetEdge(slice_allgathers[i - 1], 1, depend_node);
}
if (next_cnode == nullptr) {
MS_LOG(WARNING) << "cannot find the bprop node in 100 loop";
return;
}
auto allgather_depend_node = BpropInput(next_cnode);
if (allgather_depend_node == nullptr) {
MS_LOG(WARNING) << "cannot find the bprob input for allgather to depend.";
return;
}
MS_LOG(INFO) << "Insert depend for last slice allgather. The depend node is: "
<< allgather_depend_node->DebugString();
std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), last_allgather->input(1),
allgather_depend_node};
auto depend_node = graph->NewCNode(depend_inputs);
MS_EXCEPTION_IF_NULL(depend_node);
depend_node->set_abstract(last_allgather->input(1)->abstract()->Clone());
depend_node->AddAttr("last_slice_allgather_depend", MakeValue(true));
manager->SetEdge(last_allgather, 1, depend_node);
}
} // namespace
void SliceRecomputedActivationNodes(const FuncGraphPtr &graph) {
if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::SEMI_AUTO_PARALLEL &&
parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::AUTO_PARALLEL) {
return;
}
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
std::list<CNodePtr> orders = graph->GetOrderedCnodes();
std::vector<CNodePtr> origin_nodes_topological(orders.begin(), orders.end());
std::vector<CNodePtr> slice_allgathers;
int64_t recompute_order_id = 0;
for (auto &node : origin_nodes_topological) {
if (!node->HasAttr(kAttrSliceActivation) || IsPrimitiveCNode(node, prim::kPrimTranspose) ||
!node->has_user_data<parallel::OperatorInfo>()) {
continue;
}
auto node_users = manager->node_users()[node];
std::vector<std::pair<std::shared_ptr<AnfNode>, int>> duplicate_users;
std::vector<std::pair<std::shared_ptr<AnfNode>, int>> forward_users;
GroupingNextNodes(node, &duplicate_users, &forward_users);
if (duplicate_users.empty() || forward_users.empty()) {
continue;
}
InsertSliceAllGatherNode(duplicate_users, forward_users[0], node, &slice_allgathers, recompute_order_id);
recompute_order_id++;
}
if (slice_allgathers.size() == 0) {
return;
}
if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
int64_t current_micro = -1;
std::vector<CNodePtr> stage_slice_allgathers;
for (auto &slice_allgather_node : slice_allgathers) {
if (!slice_allgather_node->HasPrimalAttr(parallel::MICRO)) {
MS_LOG(EXCEPTION) << "In pipeline parallel mode, cannot find 'micro' attributes in node.";
}
int64_t micro = GetValue<int64_t>(slice_allgather_node->GetPrimalAttr(parallel::MICRO));
if (micro > current_micro) {
if (current_micro != -1) {
MS_LOG(INFO) << "Insert allgather depends, micro is: " << current_micro;
InsertAllGatherDepend(graph, stage_slice_allgathers);
}
stage_slice_allgathers.clear();
stage_slice_allgathers.push_back(slice_allgather_node);
current_micro = micro;
} else if (micro == current_micro) {
stage_slice_allgathers.push_back(slice_allgather_node);
} else if (current_micro != -1) {
MS_LOG(EXCEPTION) << "The micro number dose not match the execution orders in pipeline parallel";
}
}
MS_LOG(INFO) << "Insert last stage allgather depends, micro is: " << current_micro;
InsertAllGatherDepend(graph, stage_slice_allgathers);
} else {
InsertAllGatherDepend(graph, slice_allgathers);
}
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,28 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_SLICE_ACTIVATION_IN_RECOMPUTE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_SLICE_ACTIVATION_IN_RECOMPUTE_H_
#include "ir/anf.h"
namespace mindspore {
namespace opt {
// In semi-auto / auto parallel mode, for each CNode of `graph` marked with the slice_activation
// attribute, insert a StridedSlice followed by an AllGather between that node and its duplicated users,
// then add Depend edges between the inserted AllGather nodes. Does nothing in other parallel modes.
void SliceRecomputedActivationNodes(const FuncGraphPtr &graph);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_SLICE_ACTIVATION_IN_RECOMPUTE_H_

View File

@ -28,7 +28,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
using Args = std::vector<std::int64_t>; using Args = std::vector<std::int64_t>;
Operator CreateStridedSliceOp(int64_t value, const Shape &begin, const Shape &end, const Shape &strides);
class ConstructOperator { class ConstructOperator {
public: public:
const int64_t DEFAULT = 0; const int64_t DEFAULT = 0;

View File

@ -40,6 +40,7 @@
#include "frontend/parallel/cache_embedding/cache_embedding.h" #include "frontend/parallel/cache_embedding/cache_embedding.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/optimizer/recompute.h" #include "frontend/optimizer/recompute.h"
#include "frontend/optimizer/slice_activation_in_recompute.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h" #include "pipeline/jit/pipeline_split.h"
#include "pipeline/pynative/pynative_execute.h" #include "pipeline/pynative/pynative_execute.h"
@ -594,6 +595,12 @@ bool AddRecomputationPass(const ResourcePtr &res) {
return true; return true;
} }
// Pipeline pass wrapper: runs opt::SliceRecomputedActivationNodes on the func graph held by `res`.
// Throws (via MS_EXCEPTION_IF_NULL) when `res` is null; otherwise always returns true.
bool SliceRecomputeActivationPass(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  auto root_graph = res->func_graph();
  opt::SliceRecomputedActivationNodes(root_graph);
  return true;
}
bool AddCacheEmbeddingPass(const ResourcePtr &res) { bool AddCacheEmbeddingPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res); MS_EXCEPTION_IF_NULL(res);
#if ((defined ENABLE_CPU) && (!defined _WIN32)) #if ((defined ENABLE_CPU) && (!defined _WIN32))
@ -734,7 +741,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"tuple_transform", OptPassTransformGraphGroup}, {"tuple_transform", OptPassTransformGraphGroup},
{"add_cache_embedding", AddCacheEmbeddingPass}, {"add_cache_embedding", AddCacheEmbeddingPass},
{"add_recomputation", AddRecomputationPass}, {"add_recomputation", AddRecomputationPass},
{"cse_after_recomputation", OptAfterRecomputeGroup}}; {"cse_after_recomputation", OptAfterRecomputeGroup},
{"slice_recompute_activation", SliceRecomputeActivationPass}};
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup}, {"opt_a", OptPassAGroup},

View File

@ -471,6 +471,7 @@ constexpr auto kAttrPadding = "padding";
constexpr auto kAttrNonTask = "non_task"; constexpr auto kAttrNonTask = "non_task";
constexpr auto kAttrIsGrad = "is_grad"; constexpr auto kAttrIsGrad = "is_grad";
constexpr auto kAttrRecompute = "recompute"; constexpr auto kAttrRecompute = "recompute";
constexpr auto kAttrSliceActivation = "slice_activation";
constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute"; constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
constexpr auto kAttrParallelDimInfo = "parallel_dim_info"; constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
constexpr auto kAttrParallelFusionType = "parallel_fusion_type"; constexpr auto kAttrParallelFusionType = "parallel_fusion_type";

View File

@ -1499,6 +1499,16 @@ class Cell(Cell_):
for param in self.trainable_params(): for param in self.trainable_params():
param.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute param.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute
def _recompute_slice_activation(self, slice_activation=False):
"""
Slice the cell output that would remain in memory.
"""
for _, value in self._primitives.items():
if value:
value.add_prim_attr("slice_activation", slice_activation)
for cell in self.cells():
cell._recompute_slice_activation(slice_activation)
def _recompute(self, mode=True, output_recompute=False): def _recompute(self, mode=True, output_recompute=False):
""" """
Set the cell recomputed. Set the cell recomputed.
@ -1554,9 +1564,11 @@ class Cell(Cell_):
raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard " raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard "
"are not support recomputation in pipeline parallel.") "are not support recomputation in pipeline parallel.")
self._parallel_optimizer_comm_recompute(kwargs['parallel_optimizer_comm_recompute']) self._parallel_optimizer_comm_recompute(kwargs['parallel_optimizer_comm_recompute'])
if 'recompute_slice_activation' in kwargs.keys():
self._recompute_slice_activation(kwargs['recompute_slice_activation'])
for key, _ in kwargs.items(): for key, _ in kwargs.items():
if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute'): if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'):
raise ValueError("Recompute keyword %s is not recognized!" % key) raise ValueError("Recompute keyword %s is not recognized!" % key)
def infer_param_pipeline_stage(self): def infer_param_pipeline_stage(self):

View File

@ -78,7 +78,6 @@ class OpParallelConfig(_Config):
Validator.check_positive_int(value, "model_parallel") Validator.check_positive_int(value, "model_parallel")
self._model_parallel = value self._model_parallel = value
class _PipeLineConfig(_Config): class _PipeLineConfig(_Config):
r""" r"""
PPConfig for the setting data parallel, model parallel PPConfig for the setting data parallel, model parallel

View File

@ -48,7 +48,8 @@ __all__ = [
"TransformerDecoderLayer", "TransformerDecoderLayer",
"Transformer", "Transformer",
"TransformerOpParallelConfig", "TransformerOpParallelConfig",
"EmbeddingOpParallelConfig"] "EmbeddingOpParallelConfig",
"TransformerRecomputeConfig"]
class EmbeddingOpParallelConfig(_Config): class EmbeddingOpParallelConfig(_Config):
@ -112,6 +113,76 @@ class EmbeddingOpParallelConfig(_Config):
return self._dp_mp_config return self._dp_mp_config
class TransformerRecomputeConfig(_Config):
r"""
TransformerRecomputeConfig for the setting recompute attributes for encoder/decoder layers.
Args:
recompute (bool): Enable recomputation of the transformer block or not. Default: False.
parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
Default: False.
mp_comm_recompute (bool): Specifies whether the model parallel communication operators
in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
recompute_slice_activation (bool): Slice the cell output that would remain in memory. Default: False.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> config=TransformerRecomputeConfig(recompute=True, parallel_optimizer_comm_recompute=True,
>>> mp_comm_recompute=True, recompute_slice_activation=True)
"""
def __init__(self, recompute=False, parallel_optimizer_comm_recompute=False,
mp_comm_recompute=True, recompute_slice_activation=False):
Validator.check_bool(recompute, "recompute")
Validator.check_bool(parallel_optimizer_comm_recompute, "parallel_optimizer_comm_recompute")
Validator.check_bool(mp_comm_recompute, "mp_comm_recompute")
Validator.check_bool(recompute_slice_activation, "recompute_slice_activation")
self._recompute = recompute
self._parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute
self._mp_comm_recompute = mp_comm_recompute
self._recompute_slice_activation = recompute_slice_activation
@property
def recompute(self):
return self._recompute
@recompute.setter
def recompute(self, value):
Validator.check_bool(value, "recompute")
self._recompute = value
@property
def parallel_optimizer_comm_recompute(self):
return self._parallel_optimizer_comm_recompute
@parallel_optimizer_comm_recompute.setter
def parallel_optimizer_comm_recompute(self, value):
Validator.check_bool(value, "parallel_optimizer_comm_recompute")
self._parallel_optimizer_comm_recompute = value
@property
def mp_comm_recompute(self):
return self._mp_comm_recompute
@mp_comm_recompute.setter
def mp_comm_recompute(self, value):
Validator.check_bool(value, "mp_comm_recompute")
self._mp_comm_recompute = value
@property
def recompute_slice_activation(self):
return self._recompute_slice_activation
@recompute_slice_activation.setter
def recompute_slice_activation(self, value):
Validator.check_bool(value, "recompute_slice_activation")
self._recompute_slice_activation = value
_DEFALUT_TRANSFORMER_RECOMPUTE_CONFIG = TransformerRecomputeConfig()
class TransformerOpParallelConfig(_Config): class TransformerOpParallelConfig(_Config):
r""" r"""
TransformerOpParallelConfig for the setting global data parallel, model parallel and fusion group. TransformerOpParallelConfig for the setting global data parallel, model parallel and fusion group.
@ -131,17 +202,21 @@ class TransformerOpParallelConfig(_Config):
micro_batch_num (int): The microe size of the batches for the pipeline training. Default: 1. micro_batch_num (int): The microe size of the batches for the pipeline training. Default: 1.
optimizer_shard (bool): Whether to enable optimizer shard. Default False. optimizer_shard (bool): Whether to enable optimizer shard. Default False.
gradient_aggregation_group (int): The fusion group size of the optimizer state sharding. Default: 4. gradient_aggregation_group (int): The fusion group size of the optimizer state sharding. Default: 4.
recompute (bool): Enable recomputation of the transformer block or not. Default: False. recompute (Union[TransformerRecomputeConfig, bool]): The configuration of recomputation for
the transformer block. Default: The default configuration of TransformerRecomputeConfig.
vocab_emb_dp (bool): Shard embedding in model parallel or data parallel. Default: True. vocab_emb_dp (bool): Shard embedding in model parallel or data parallel. Default: True.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``Ascend`` ``GPU``
Examples: Examples:
>>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1) >>> recompute_config=TransformerRecomputeConfig(recompute=True, parallel_optimizer_comm_recompute=True,
>>> mp_comm_recompute=True, recompute_slice_activation=True)
>>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config)
""" """
def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, recompute=False, def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1,
recompute=_DEFALUT_TRANSFORMER_RECOMPUTE_CONFIG,
optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True): optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
self.recompute = recompute self.recompute = recompute
self.optimizer_shard = optimizer_shard self.optimizer_shard = optimizer_shard
@ -156,7 +231,10 @@ class TransformerOpParallelConfig(_Config):
@recompute.setter @recompute.setter
def recompute(self, value): def recompute(self, value):
Validator.check_bool(value, "recompute") if not isinstance(value, TransformerRecomputeConfig) and not isinstance(value, bool):
raise TypeError(f"recompute should be a TransformerRecomputeConfig/bool, but got {type(value).__name__}.")
if isinstance(value, bool):
logger.warning(f"TransformerRecomputeConfig is recommended as the recompute configuration type.")
self._recompute = value self._recompute = value
@property @property
@ -1726,8 +1804,15 @@ def _get_lambda_func(total_layer=None):
dis = max(int(layers / parallel_config.gradient_aggregation_group), 1) dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
network.set_comm_fusion(int((layer_id + offset) / dis) + 1) network.set_comm_fusion(int((layer_id + offset) / dis) + 1)
# Used for enabling recomputation of the block # Used for enabling recomputation of the block
if parallel_config.recompute: if isinstance(parallel_config.recompute, bool):
network.recompute() if parallel_config.recompute:
network.recompute()
else:
if parallel_config.recompute.recompute:
network.recompute(parallel_optimizer_comm_recompute=
parallel_config.recompute.parallel_optimizer_comm_recompute,
mp_comm_recompute=parallel_config.recompute.mp_comm_recompute,
recompute_slice_activation=parallel_config.recompute.recompute_slice_activation)
return _set_parallel_configure_for_layer return _set_parallel_configure_for_layer

View File

@ -27,8 +27,8 @@ class MatMulCell(nn.Cell):
def __init__(self): def __init__(self):
super(MatMulCell, self).__init__() super(MatMulCell, self).__init__()
self.reshape = P.Reshape() self.reshape = P.Reshape()
self.matmul0 = P.MatMul() self.matmul0 = P.MatMul(transpose_b=True)
self.weight = Parameter(initializer("ones", [128, 64], ms.float32), name="weight") self.weight = Parameter(initializer("ones", [64, 128], ms.float32), name="weight")
self.relu = P.ReLU().shard(((1, 8),)) self.relu = P.ReLU().shard(((1, 8),))
def construct(self, x): def construct(self, x):
x = self.matmul0(x, self.weight) x = self.matmul0(x, self.weight)
@ -37,25 +37,26 @@ class MatMulCell(nn.Cell):
return x return x
class DenseMutMulNet(nn.Cell): class DenseMutMulNet(nn.Cell):
def __init__(self): def __init__(self, mp_comm_recompute=True, recompute_slice_activation=False):
super(DenseMutMulNet, self).__init__() super(DenseMutMulNet, self).__init__()
self.fc1 = nn.Dense(128, 768, activation='relu') self.fc1 = nn.Dense(128, 768, activation='relu')
self.fc2 = nn.Dense(128, 768, activation='relu') self.fc2 = nn.Dense(128, 768, activation='relu')
self.fc3 = nn.Dense(128, 768, activation='relu') self.fc3 = nn.Dense(128, 768, activation='relu')
self.fc4 = nn.Dense(768, 768, activation='relu') self.fc4 = nn.Dense(768, 768, activation='relu')
self.fc1.matmul.shard(((1, 1), (1, 8))) self.fc1.matmul.shard(((1, 1), (8, 1)))
self.fc2.matmul.shard(((1, 1), (1, 8))) self.fc2.matmul.shard(((1, 1), (8, 1)))
self.fc3.matmul.shard(((1, 1), (1, 8))) self.fc3.matmul.shard(((1, 1), (8, 1)))
self.relu4 = nn.ReLU() self.relu4 = nn.ReLU()
self.relu5 = nn.ReLU() self.relu5 = nn.ReLU()
self.transpose = P.Transpose() self.transpose = P.Transpose()
self.matmul1 = P.MatMul() self.matmul1 = P.MatMul()
self.matmul2 = P.MatMul() self.matmul2 = P.MatMul()
self.matmul_cell = MatMulCell() self.matmul_cell = MatMulCell()
self.fc1.recompute(mp_comm_recompute=False) self.fc1.recompute(mp_comm_recompute=mp_comm_recompute, recompute_slice_activation=recompute_slice_activation)
self.fc2.recompute(mp_comm_recompute=False) self.fc2.recompute(mp_comm_recompute=mp_comm_recompute, recompute_slice_activation=recompute_slice_activation)
self.fc3.recompute(mp_comm_recompute=False) self.fc3.recompute(mp_comm_recompute=mp_comm_recompute, recompute_slice_activation=recompute_slice_activation)
self.matmul_cell.recompute(mp_comm_recompute=False) self.matmul_cell.recompute(mp_comm_recompute=mp_comm_recompute,
recompute_slice_activation=recompute_slice_activation)
def construct(self, x): def construct(self, x):
x = self.matmul_cell(x) x = self.matmul_cell(x)
@ -68,16 +69,40 @@ class DenseMutMulNet(nn.Cell):
s = self.fc4(s) s = self.fc4(s)
return s return s
def compile_net(mp_comm_recompute, recompute_slice_activation):
def test_dmnet_train_step():
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
_Context().set_backend_policy("vm") _Context().set_backend_policy("vm")
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
input_ = Tensor(np.ones([64, 128]).astype(np.float32) * 0.01) input_ = Tensor(np.ones([64, 128]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([32, 768]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = train_step_with_loss_warp(DenseMutMulNet()) net = train_step_with_loss_warp(DenseMutMulNet(mp_comm_recompute=mp_comm_recompute,
recompute_slice_activation=recompute_slice_activation))
net.set_auto_parallel() net.set_auto_parallel()
net.set_train() net.set_train()
_cell_graph_executor.compile(net, input_, label) _cell_graph_executor.compile(net, input_, label)
_Context().set_backend_policy("ge") _Context().set_backend_policy("ge")
def test_dmnet_train_step_mp_recompute():
"""
Feature: test recompute interface.
Description: test that model parallel communication operators are not recomputed.
Expectation: compile without error.
"""
compile_net(False, False)
def test_dmnet_train_step_recompute_activation_slice():
"""
Feature: test recompute interface.
Description: test slicing recompute cell output.
Expectation: compile without error.
"""
compile_net(True, True)
def test_dmnet_train_step_mp_recompute_recompute_activation_slice():
"""
Feature: test recompute interface.
Description: test that model parallel communication operators are not recomputed while the recompute cell output is sliced.
Expectation: compile without error.
"""
compile_net(False, True)

View File

@ -652,9 +652,9 @@ def test_transformer_parallel_config():
with pytest.raises(TypeError): with pytest.raises(TypeError):
parallel_test_config.recompute = 1 parallel_test_config.recompute = 1
parallel_test_config.recompute = False parallel_test_config.recompute.recompute = False
assert not parallel_test_config.recompute assert not parallel_test_config.recompute.recompute
def test_parallel_config(): def test_parallel_config():