forked from mindspore-Ecosystem/mindspore
parent
5c573e6d7d
commit
188d39da83
|
@ -21,6 +21,7 @@
|
|||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -47,6 +48,13 @@ class SetCellOutputNoRecompute : public AnfVisitor {
|
|||
for (const auto &real_output : real_outputs) {
|
||||
// Set the attr of cnode in case of shared primitives.
|
||||
real_output->AddAttr(kAttrRecompute, MakeValue(false));
|
||||
if (parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::SEMI_AUTO_PARALLEL ||
|
||||
parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::AUTO_PARALLEL) {
|
||||
auto prim = GetCNodePrimitive(real_output);
|
||||
if (prim->HasAttr(kAttrSliceActivation) && GetValue<bool>(prim->GetAttr(kAttrSliceActivation))) {
|
||||
real_output->AddAttr(kAttrSliceActivation, MakeValue(true));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fg->erase_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE);
|
||||
|
|
|
@ -0,0 +1,309 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/optimizer/slice_activation_in_recompute.h"
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "mindspore/core/base/core_ops.h"
|
||||
#include "utils/utils.h"
|
||||
#include "frontend/parallel/tensor_layout/construct_operator.h"
|
||||
#include "frontend/parallel/step_parallel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto kGradientsFlag = "Gradients";
|
||||
const int64_t max_loop_size = 100;
|
||||
bool IsBpropNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
return node->fullname_with_scope().find(kGradientsFlag) == 0;
|
||||
}
|
||||
|
||||
CNodePtr CreateStridedSliceCNode(const parallel::Shape &begin, const parallel::Shape &end,
|
||||
const parallel::Shape &strides, const AnfNodePtr &node) {
|
||||
auto slice_op = parallel::CreateStridedSliceOp(0, begin, end, strides);
|
||||
auto slice_input = parallel::CreateInput(slice_op, node, parallel::STRIDEDSLICE);
|
||||
auto func_graph = node->func_graph();
|
||||
CNodePtr new_node = func_graph->NewCNode(slice_input);
|
||||
return new_node;
|
||||
}
|
||||
|
||||
CNodePtr CreateAllGatherCNode(const AnfNodePtr &node, std::string group) {
|
||||
auto op = parallel::CreateAllGatherOp(group);
|
||||
auto allgather_input = parallel::CreateInput(op, node, "recompute_slice_allgather");
|
||||
auto func_graph = node->func_graph();
|
||||
CNodePtr new_node = func_graph->NewCNode(allgather_input);
|
||||
return new_node;
|
||||
}
|
||||
|
||||
parallel::Group InferRepeatedRankList(const CNodePtr &cnode) {
|
||||
OperatorInfoPtr operator_info = cnode->user_data<parallel::OperatorInfo>();
|
||||
std::vector<parallel::TensorInfo> output_info = operator_info->outputs_tensor_info();
|
||||
if (output_info.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "The output_info size is wrong, node is" << cnode->DebugString();
|
||||
}
|
||||
auto tensor_layout = output_info[0].tensor_layout();
|
||||
auto tensor_map = tensor_layout.origin_tensor_map();
|
||||
std::vector<parallel::Group> groups;
|
||||
operator_info->CreateGroupByTensorMap(tensor_map.array(), &groups);
|
||||
return groups[0];
|
||||
}
|
||||
|
||||
bool IsDuplicateNode(const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
if (node->cast<CNodePtr>()->HasAttr(kAttrDuplicated)) {
|
||||
return true;
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
|
||||
auto manager = node->func_graph()->manager();
|
||||
auto node_users = manager->node_users()[node];
|
||||
return std::any_of(node_users.begin(), node_users.end(),
|
||||
[](auto node_user) { return IsDuplicateNode(node_user.first); });
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
CNodePtr BpropInput(const CNodePtr cnode) {
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
if (IsBpropNode(cnode->input(i))) {
|
||||
return cnode->input(i)->cast<CNodePtr>();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> NextBpropNodes(const CNodePtr cnode) {
|
||||
std::vector<CNodePtr> next_brop_nodes;
|
||||
auto manager = cnode->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto cnode_users = manager->node_users()[cnode];
|
||||
std::queue<CNodePtr> cnode_queue;
|
||||
for (auto &node : cnode_users) {
|
||||
if (!node.first->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
cnode_queue.push(node.first->cast<CNodePtr>());
|
||||
}
|
||||
int64_t loop_size = 0;
|
||||
while (!cnode_queue.empty() && loop_size < max_loop_size) {
|
||||
auto cur_cnode = cnode_queue.front();
|
||||
cnode_queue.pop();
|
||||
if (IsBpropNode(cur_cnode)) {
|
||||
next_brop_nodes.push_back(cur_cnode);
|
||||
continue;
|
||||
}
|
||||
auto cur_cnode_users = manager->node_users()[cur_cnode];
|
||||
for (auto &node : cur_cnode_users) {
|
||||
if (!node.first->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
cnode_queue.push(node.first->cast<CNodePtr>());
|
||||
}
|
||||
loop_size++;
|
||||
}
|
||||
if (loop_size < max_loop_size) {
|
||||
return next_brop_nodes;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void GroupingNextNodes(const CNodePtr &node, std::vector<std::pair<std::shared_ptr<AnfNode>, int>> *duplicate_users,
|
||||
std::vector<std::pair<std::shared_ptr<AnfNode>, int>> *forward_users) {
|
||||
auto manager = node->func_graph()->manager();
|
||||
auto root_node = node;
|
||||
auto node_users = manager->node_users()[root_node];
|
||||
for (auto node_user : node_users) {
|
||||
if (IsDuplicateNode(node_user.first)) {
|
||||
duplicate_users->push_back(node_user);
|
||||
} else {
|
||||
forward_users->push_back(node_user);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void InsertSliceAllGatherNode(const std::vector<std::pair<std::shared_ptr<AnfNode>, int>> &node_users,
|
||||
const std::pair<std::shared_ptr<AnfNode>, int> &forward_node_user,
|
||||
const std::shared_ptr<CNode> &node, std::vector<CNodePtr> *slice_allgathers,
|
||||
int64_t recompute_order_id) {
|
||||
auto manager = node->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto output_shape = node->abstract()->BuildShape();
|
||||
std::vector<int64_t> out_shape_element = output_shape->cast<abstract::ShapePtr>()->shape();
|
||||
if (out_shape_element.empty()) {
|
||||
return;
|
||||
}
|
||||
int64_t global_rank_id = parallel::g_device_manager->global_rank();
|
||||
int64_t stage_num = parallel::g_device_manager->stage_num();
|
||||
int64_t device_num = parallel::g_device_manager->DeviceNum();
|
||||
int64_t stage_device_num = device_num / stage_num;
|
||||
int64_t local_rank_id = global_rank_id % stage_device_num;
|
||||
auto group = InferRepeatedRankList(node);
|
||||
if (out_shape_element[0] % group.GetDevNum() != 0) {
|
||||
MS_LOG(WARNING) << "The output_shape first dim:" << out_shape_element[0]
|
||||
<< " cannot be divisible by the repeated size: " << group.GetDevNum()
|
||||
<< "The slice would not activate to this node: " << node->DebugString();
|
||||
return;
|
||||
}
|
||||
int64_t group_deivce_num = group.GetDevNum();
|
||||
std::vector<int64_t> slice_begin(out_shape_element.size(), 0);
|
||||
slice_begin[0] = (local_rank_id % group_deivce_num) * (out_shape_element[0] / group_deivce_num);
|
||||
std::vector<int64_t> slice_end = out_shape_element;
|
||||
slice_end[0] = (local_rank_id % group_deivce_num + 1) * (out_shape_element[0] / group_deivce_num);
|
||||
std::vector<int64_t> slice_strides(out_shape_element.size(), 1);
|
||||
CNodePtr slice_cnode = CreateStridedSliceCNode(slice_begin, slice_end, slice_strides, node);
|
||||
slice_cnode->set_abstract(node->abstract()->Clone());
|
||||
std::vector<int64_t> slice_shape = out_shape_element;
|
||||
slice_shape[0] = out_shape_element[0] / group_deivce_num;
|
||||
std::shared_ptr<abstract::BaseShape> slice_base_shape = std::make_shared<abstract::Shape>(slice_shape);
|
||||
slice_cnode->abstract()->set_shape(slice_base_shape);
|
||||
for (auto &node_user : node_users) {
|
||||
manager->SetEdge(node_user.first, node_user.second, slice_cnode);
|
||||
}
|
||||
|
||||
CNodePtr allgather_cnode = CreateAllGatherCNode(slice_cnode, group.name());
|
||||
allgather_cnode->set_abstract(node->abstract()->Clone());
|
||||
allgather_cnode->AddAttr("recompute_order", MakeValue(recompute_order_id));
|
||||
if (node->HasPrimalAttr(parallel::MICRO)) {
|
||||
allgather_cnode->AddPrimalAttr(parallel::MICRO, node->GetPrimalAttr(parallel::MICRO));
|
||||
}
|
||||
manager->Replace(slice_cnode, allgather_cnode);
|
||||
slice_allgathers->push_back(allgather_cnode);
|
||||
|
||||
std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), forward_node_user.first, slice_cnode};
|
||||
auto depend_node = node->func_graph()->NewCNode(depend_inputs);
|
||||
depend_node->set_abstract(forward_node_user.first->abstract()->Clone());
|
||||
depend_node->AddAttr("slice_forward_depend", MakeValue(true));
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
manager->Replace(forward_node_user.first, depend_node);
|
||||
}
|
||||
|
||||
void InsertAllGatherDepend(const FuncGraphPtr &graph, const std::vector<CNodePtr> &slice_allgathers) {
|
||||
auto manager = graph->manager();
|
||||
auto last_allgather = slice_allgathers.back();
|
||||
auto next_cnodes = NextBpropNodes(last_allgather);
|
||||
CNodePtr next_cnode = nullptr;
|
||||
if (!next_cnodes.empty()) {
|
||||
MS_LOG(INFO) << "The next_cnodes is not empty.";
|
||||
std::list<CNodePtr> orders = graph->GetOrderedCnodes();
|
||||
for (auto &cnode : orders) {
|
||||
if (std::find(next_cnodes.begin(), next_cnodes.end(), cnode) != next_cnodes.end()) {
|
||||
next_cnode = cnode;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = slice_allgathers.size() - 1; i > 0; --i) {
|
||||
std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), slice_allgathers[i - 1]->input(1),
|
||||
slice_allgathers[i]};
|
||||
auto depend_node = graph->NewCNode(depend_inputs);
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
depend_node->set_abstract(slice_allgathers[i - 1]->input(1)->abstract()->Clone());
|
||||
depend_node->AddAttr("slice_allgather_depend", MakeValue(i));
|
||||
manager->SetEdge(slice_allgathers[i - 1], 1, depend_node);
|
||||
}
|
||||
|
||||
if (next_cnode == nullptr) {
|
||||
MS_LOG(WARNING) << "cannot find the bprop node in 100 loop";
|
||||
return;
|
||||
}
|
||||
auto allgather_depend_node = BpropInput(next_cnode);
|
||||
if (allgather_depend_node == nullptr) {
|
||||
MS_LOG(WARNING) << "cannot find the bprob input for allgather to depend.";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Insert depend for last slice allgather. The depend node is: "
|
||||
<< allgather_depend_node->DebugString();
|
||||
std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), last_allgather->input(1),
|
||||
allgather_depend_node};
|
||||
auto depend_node = graph->NewCNode(depend_inputs);
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
depend_node->set_abstract(last_allgather->input(1)->abstract()->Clone());
|
||||
depend_node->AddAttr("last_slice_allgather_depend", MakeValue(true));
|
||||
manager->SetEdge(last_allgather, 1, depend_node);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void SliceRecomputedActivationNodes(const FuncGraphPtr &graph) {
|
||||
if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::SEMI_AUTO_PARALLEL &&
|
||||
parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::AUTO_PARALLEL) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::list<CNodePtr> orders = graph->GetOrderedCnodes();
|
||||
std::vector<CNodePtr> origin_nodes_topological(orders.begin(), orders.end());
|
||||
std::vector<CNodePtr> slice_allgathers;
|
||||
int64_t recompute_order_id = 0;
|
||||
for (auto &node : origin_nodes_topological) {
|
||||
if (!node->HasAttr(kAttrSliceActivation) || IsPrimitiveCNode(node, prim::kPrimTranspose) ||
|
||||
!node->has_user_data<parallel::OperatorInfo>()) {
|
||||
continue;
|
||||
}
|
||||
auto node_users = manager->node_users()[node];
|
||||
std::vector<std::pair<std::shared_ptr<AnfNode>, int>> duplicate_users;
|
||||
std::vector<std::pair<std::shared_ptr<AnfNode>, int>> forward_users;
|
||||
GroupingNextNodes(node, &duplicate_users, &forward_users);
|
||||
if (duplicate_users.empty() || forward_users.empty()) {
|
||||
continue;
|
||||
}
|
||||
InsertSliceAllGatherNode(duplicate_users, forward_users[0], node, &slice_allgathers, recompute_order_id);
|
||||
recompute_order_id++;
|
||||
}
|
||||
if (slice_allgathers.size() == 0) {
|
||||
return;
|
||||
}
|
||||
if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
|
||||
int64_t current_micro = -1;
|
||||
std::vector<CNodePtr> stage_slice_allgathers;
|
||||
for (auto &slice_allgather_node : slice_allgathers) {
|
||||
if (!slice_allgather_node->HasPrimalAttr(parallel::MICRO)) {
|
||||
MS_LOG(EXCEPTION) << "In pipeline parallel mode, cannot find 'micro' attributes in node.";
|
||||
}
|
||||
int64_t micro = GetValue<int64_t>(slice_allgather_node->GetPrimalAttr(parallel::MICRO));
|
||||
if (micro > current_micro) {
|
||||
if (current_micro != -1) {
|
||||
MS_LOG(INFO) << "Insert allgather depends, micro is: " << current_micro;
|
||||
InsertAllGatherDepend(graph, stage_slice_allgathers);
|
||||
}
|
||||
stage_slice_allgathers.clear();
|
||||
stage_slice_allgathers.push_back(slice_allgather_node);
|
||||
current_micro = micro;
|
||||
} else if (micro == current_micro) {
|
||||
stage_slice_allgathers.push_back(slice_allgather_node);
|
||||
} else if (current_micro != -1) {
|
||||
MS_LOG(EXCEPTION) << "The micro number dose not match the execution orders in pipeline parallel";
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Insert last stage allgather depends, micro is: " << current_micro;
|
||||
InsertAllGatherDepend(graph, stage_slice_allgathers);
|
||||
} else {
|
||||
InsertAllGatherDepend(graph, slice_allgathers);
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_SLICE_ACTIVATION_IN_RECOMPUTE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_SLICE_ACTIVATION_IN_RECOMPUTE_H_
|
||||
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Automatically insert duplicated recomputed nodes.
|
||||
void SliceRecomputedActivationNodes(const FuncGraphPtr &graph);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_SLICE_ACTIVATION_IN_RECOMPUTE_H_
|
|
@ -28,7 +28,7 @@
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
using Args = std::vector<std::int64_t>;
|
||||
|
||||
Operator CreateStridedSliceOp(int64_t value, const Shape &begin, const Shape &end, const Shape &strides);
|
||||
class ConstructOperator {
|
||||
public:
|
||||
const int64_t DEFAULT = 0;
|
||||
|
|
|
@ -40,6 +40,7 @@
|
|||
#include "frontend/parallel/cache_embedding/cache_embedding.h"
|
||||
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
|
||||
#include "frontend/optimizer/recompute.h"
|
||||
#include "frontend/optimizer/slice_activation_in_recompute.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/jit/pipeline_split.h"
|
||||
#include "pipeline/pynative/pynative_execute.h"
|
||||
|
@ -594,6 +595,12 @@ bool AddRecomputationPass(const ResourcePtr &res) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool SliceRecomputeActivationPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
opt::SliceRecomputedActivationNodes(res->func_graph());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res);
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
|
@ -734,7 +741,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
|
|||
{"tuple_transform", OptPassTransformGraphGroup},
|
||||
{"add_cache_embedding", AddCacheEmbeddingPass},
|
||||
{"add_recomputation", AddRecomputationPass},
|
||||
{"cse_after_recomputation", OptAfterRecomputeGroup}};
|
||||
{"cse_after_recomputation", OptAfterRecomputeGroup},
|
||||
{"slice_recompute_activation", SliceRecomputeActivationPass}};
|
||||
|
||||
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||
{"opt_a", OptPassAGroup},
|
||||
|
|
|
@ -471,6 +471,7 @@ constexpr auto kAttrPadding = "padding";
|
|||
constexpr auto kAttrNonTask = "non_task";
|
||||
constexpr auto kAttrIsGrad = "is_grad";
|
||||
constexpr auto kAttrRecompute = "recompute";
|
||||
constexpr auto kAttrSliceActivation = "slice_activation";
|
||||
constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
|
||||
constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
|
||||
constexpr auto kAttrParallelFusionType = "parallel_fusion_type";
|
||||
|
|
|
@ -1499,6 +1499,16 @@ class Cell(Cell_):
|
|||
for param in self.trainable_params():
|
||||
param.parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute
|
||||
|
||||
def _recompute_slice_activation(self, slice_activation=False):
|
||||
"""
|
||||
Slice the cell output which would remains in memory.
|
||||
"""
|
||||
for _, value in self._primitives.items():
|
||||
if value:
|
||||
value.add_prim_attr("slice_activation", slice_activation)
|
||||
for cell in self.cells():
|
||||
cell._recompute_slice_activation(slice_activation)
|
||||
|
||||
def _recompute(self, mode=True, output_recompute=False):
|
||||
"""
|
||||
Set the cell recomputed.
|
||||
|
@ -1554,9 +1564,11 @@ class Cell(Cell_):
|
|||
raise ValueError("Currently, the communication operator allgathers introduced by optimizer shard "
|
||||
"are not support recomputation in pipeline parallel.")
|
||||
self._parallel_optimizer_comm_recompute(kwargs['parallel_optimizer_comm_recompute'])
|
||||
if 'recompute_slice_activation' in kwargs.keys():
|
||||
self._recompute_slice_activation(kwargs['recompute_slice_activation'])
|
||||
|
||||
for key, _ in kwargs.items():
|
||||
if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute'):
|
||||
if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'):
|
||||
raise ValueError("Recompute keyword %s is not recognized!" % key)
|
||||
|
||||
def infer_param_pipeline_stage(self):
|
||||
|
|
|
@ -78,7 +78,6 @@ class OpParallelConfig(_Config):
|
|||
Validator.check_positive_int(value, "model_parallel")
|
||||
self._model_parallel = value
|
||||
|
||||
|
||||
class _PipeLineConfig(_Config):
|
||||
r"""
|
||||
PPConfig for the setting data parallel, model parallel
|
||||
|
|
|
@ -48,7 +48,8 @@ __all__ = [
|
|||
"TransformerDecoderLayer",
|
||||
"Transformer",
|
||||
"TransformerOpParallelConfig",
|
||||
"EmbeddingOpParallelConfig"]
|
||||
"EmbeddingOpParallelConfig",
|
||||
"TransformerRecomputeConfig"]
|
||||
|
||||
|
||||
class EmbeddingOpParallelConfig(_Config):
|
||||
|
@ -112,6 +113,76 @@ class EmbeddingOpParallelConfig(_Config):
|
|||
return self._dp_mp_config
|
||||
|
||||
|
||||
class TransformerRecomputeConfig(_Config):
|
||||
r"""
|
||||
TransformerRecomputeConfig for the setting recompute attributes for encoder/decoder layers.
|
||||
|
||||
Args:
|
||||
recompute (bool): Enable recomputation of the transformer block or not. Default: False.
|
||||
parallel_optimizer_comm_recompute (bool): The model parallel way. Default: 1
|
||||
parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
|
||||
introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
|
||||
Default: False.
|
||||
mp_comm_recompute (bool): Specifies whether the model parallel communication operators
|
||||
in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
|
||||
recompute_slice_activation (bool): Slice the cell output which would remains in memory. Default: False.
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> config=TransformerRecomputeConfig(recompute=True, parallel_optimizer_comm_recompute=True,
|
||||
>>> mp_comm_recompute=True, recompute_slice_activation=True)
|
||||
"""
|
||||
|
||||
def __init__(self, recompute=False, parallel_optimizer_comm_recompute=False,
|
||||
mp_comm_recompute=True, recompute_slice_activation=False):
|
||||
Validator.check_bool(recompute, "recompute")
|
||||
Validator.check_bool(parallel_optimizer_comm_recompute, "parallel_optimizer_comm_recompute")
|
||||
Validator.check_bool(mp_comm_recompute, "mp_comm_recompute")
|
||||
Validator.check_bool(recompute_slice_activation, "recompute_slice_activation")
|
||||
self._recompute = recompute
|
||||
self._parallel_optimizer_comm_recompute = parallel_optimizer_comm_recompute
|
||||
self._mp_comm_recompute = mp_comm_recompute
|
||||
self._recompute_slice_activation = recompute_slice_activation
|
||||
|
||||
@property
|
||||
def recompute(self):
|
||||
return self._recompute
|
||||
|
||||
@recompute.setter
|
||||
def recompute(self, value):
|
||||
Validator.check_bool(value, "recompute")
|
||||
self._recompute = value
|
||||
|
||||
@property
|
||||
def parallel_optimizer_comm_recompute(self):
|
||||
return self._parallel_optimizer_comm_recompute
|
||||
|
||||
@parallel_optimizer_comm_recompute.setter
|
||||
def parallel_optimizer_comm_recompute(self, value):
|
||||
Validator.check_bool(value, "parallel_optimizer_comm_recompute")
|
||||
self._parallel_optimizer_comm_recompute = value
|
||||
|
||||
@property
|
||||
def mp_comm_recompute(self):
|
||||
return self._mp_comm_recompute
|
||||
|
||||
@mp_comm_recompute.setter
|
||||
def mp_comm_recompute(self, value):
|
||||
Validator.check_bool(value, "mp_comm_recompute")
|
||||
self._mp_comm_recompute = value
|
||||
|
||||
@property
|
||||
def recompute_slice_activation(self):
|
||||
return self._recompute_slice_activation
|
||||
|
||||
@recompute_slice_activation.setter
|
||||
def recompute_slice_activation(self, value):
|
||||
Validator.check_bool(value, "recompute_slice_activation")
|
||||
self._recompute_slice_activation = value
|
||||
|
||||
_DEFALUT_TRANSFORMER_RECOMPUTE_CONFIG = TransformerRecomputeConfig()
|
||||
|
||||
class TransformerOpParallelConfig(_Config):
|
||||
r"""
|
||||
TransformerOpParallelConfig for the setting global data parallel, model parallel and fusion group.
|
||||
|
@ -131,17 +202,21 @@ class TransformerOpParallelConfig(_Config):
|
|||
micro_batch_num (int): The microe size of the batches for the pipeline training. Default: 1.
|
||||
optimizer_shard (bool): Whether to enable optimizer shard. Default False.
|
||||
gradient_aggregation_group (int): The fusion group size of the optimizer state sharding. Default: 4.
|
||||
recompute (bool): Enable recomputation of the transformer block or not. Default: False.
|
||||
recompute (Union[TransformerRecomputeConfig, bool]): The configuration of recomputation for
|
||||
the transformer block. Default: The default configuration of TransformerRecomputeConfig.
|
||||
vocab_emb_dp (bool): Shard embedding in model parallel or data parallel. Default: True.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1)
|
||||
>>> recompute_config=TransformerRecomputeConfig(recompute=True, parallel_optimizer_comm_recompute=True,
|
||||
>>> mp_comm_recompute=True, recompute_slice_activation=True)
|
||||
>>> config=TransformerOpParallelConfig(data_parallel=1, model_parallel=1, recompute=recompute_config)
|
||||
"""
|
||||
|
||||
def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1, recompute=False,
|
||||
def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1,
|
||||
recompute=_DEFALUT_TRANSFORMER_RECOMPUTE_CONFIG,
|
||||
optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
|
||||
self.recompute = recompute
|
||||
self.optimizer_shard = optimizer_shard
|
||||
|
@ -156,7 +231,10 @@ class TransformerOpParallelConfig(_Config):
|
|||
|
||||
@recompute.setter
|
||||
def recompute(self, value):
|
||||
Validator.check_bool(value, "recompute")
|
||||
if not isinstance(value, TransformerRecomputeConfig) and not isinstance(value, bool):
|
||||
raise TypeError(f"recompute should be a TransformerRecomputeConfig/bool, but got {type(value).__name__}.")
|
||||
if isinstance(value, bool):
|
||||
logger.warning(f"TransformerRecomputeConfig is recommended as the recompute configuration type.")
|
||||
self._recompute = value
|
||||
|
||||
@property
|
||||
|
@ -1726,8 +1804,15 @@ def _get_lambda_func(total_layer=None):
|
|||
dis = max(int(layers / parallel_config.gradient_aggregation_group), 1)
|
||||
network.set_comm_fusion(int((layer_id + offset) / dis) + 1)
|
||||
# Used for enabling recomputation of the block
|
||||
if parallel_config.recompute:
|
||||
network.recompute()
|
||||
if isinstance(parallel_config.recompute, bool):
|
||||
if parallel_config.recompute:
|
||||
network.recompute()
|
||||
else:
|
||||
if parallel_config.recompute.recompute:
|
||||
network.recompute(parallel_optimizer_comm_recompute=
|
||||
parallel_config.recompute.parallel_optimizer_comm_recompute,
|
||||
mp_comm_recompute=parallel_config.recompute.mp_comm_recompute,
|
||||
recompute_slice_activation=parallel_config.recompute.recompute_slice_activation)
|
||||
|
||||
return _set_parallel_configure_for_layer
|
||||
|
||||
|
|
|
@ -27,8 +27,8 @@ class MatMulCell(nn.Cell):
|
|||
def __init__(self):
|
||||
super(MatMulCell, self).__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.matmul0 = P.MatMul()
|
||||
self.weight = Parameter(initializer("ones", [128, 64], ms.float32), name="weight")
|
||||
self.matmul0 = P.MatMul(transpose_b=True)
|
||||
self.weight = Parameter(initializer("ones", [64, 128], ms.float32), name="weight")
|
||||
self.relu = P.ReLU().shard(((1, 8),))
|
||||
def construct(self, x):
|
||||
x = self.matmul0(x, self.weight)
|
||||
|
@ -37,25 +37,26 @@ class MatMulCell(nn.Cell):
|
|||
return x
|
||||
|
||||
class DenseMutMulNet(nn.Cell):
|
||||
def __init__(self):
|
||||
def __init__(self, mp_comm_recompute=True, recompute_slice_activation=False):
|
||||
super(DenseMutMulNet, self).__init__()
|
||||
self.fc1 = nn.Dense(128, 768, activation='relu')
|
||||
self.fc2 = nn.Dense(128, 768, activation='relu')
|
||||
self.fc3 = nn.Dense(128, 768, activation='relu')
|
||||
self.fc4 = nn.Dense(768, 768, activation='relu')
|
||||
self.fc1.matmul.shard(((1, 1), (1, 8)))
|
||||
self.fc2.matmul.shard(((1, 1), (1, 8)))
|
||||
self.fc3.matmul.shard(((1, 1), (1, 8)))
|
||||
self.fc1.matmul.shard(((1, 1), (8, 1)))
|
||||
self.fc2.matmul.shard(((1, 1), (8, 1)))
|
||||
self.fc3.matmul.shard(((1, 1), (8, 1)))
|
||||
self.relu4 = nn.ReLU()
|
||||
self.relu5 = nn.ReLU()
|
||||
self.transpose = P.Transpose()
|
||||
self.matmul1 = P.MatMul()
|
||||
self.matmul2 = P.MatMul()
|
||||
self.matmul_cell = MatMulCell()
|
||||
self.fc1.recompute(mp_comm_recompute=False)
|
||||
self.fc2.recompute(mp_comm_recompute=False)
|
||||
self.fc3.recompute(mp_comm_recompute=False)
|
||||
self.matmul_cell.recompute(mp_comm_recompute=False)
|
||||
self.fc1.recompute(mp_comm_recompute=mp_comm_recompute, recompute_slice_activation=recompute_slice_activation)
|
||||
self.fc2.recompute(mp_comm_recompute=mp_comm_recompute, recompute_slice_activation=recompute_slice_activation)
|
||||
self.fc3.recompute(mp_comm_recompute=mp_comm_recompute, recompute_slice_activation=recompute_slice_activation)
|
||||
self.matmul_cell.recompute(mp_comm_recompute=mp_comm_recompute,
|
||||
recompute_slice_activation=recompute_slice_activation)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.matmul_cell(x)
|
||||
|
@ -68,16 +69,40 @@ class DenseMutMulNet(nn.Cell):
|
|||
s = self.fc4(s)
|
||||
return s
|
||||
|
||||
|
||||
def test_dmnet_train_step():
|
||||
def compile_net(mp_comm_recompute, recompute_slice_activation):
|
||||
context.reset_auto_parallel_context()
|
||||
_Context().set_backend_policy("vm")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
|
||||
input_ = Tensor(np.ones([64, 128]).astype(np.float32) * 0.01)
|
||||
label = Tensor(np.zeros([32, 768]).astype(np.float32))
|
||||
net = train_step_with_loss_warp(DenseMutMulNet())
|
||||
net = train_step_with_loss_warp(DenseMutMulNet(mp_comm_recompute=mp_comm_recompute,
|
||||
recompute_slice_activation=recompute_slice_activation))
|
||||
net.set_auto_parallel()
|
||||
net.set_train()
|
||||
_cell_graph_executor.compile(net, input_, label)
|
||||
_Context().set_backend_policy("ge")
|
||||
|
||||
def test_dmnet_train_step_mp_recompute():
|
||||
"""
|
||||
Feature: test recompute interface.
|
||||
Description: test model parallel communication not recompute.
|
||||
Expectation: compile without error.
|
||||
"""
|
||||
compile_net(False, False)
|
||||
|
||||
def test_dmnet_train_step_recompute_activation_slice():
|
||||
"""
|
||||
Feature: test recompute interface.
|
||||
Description: test slicing recompute cell output.
|
||||
Expectation: compile without error.
|
||||
"""
|
||||
compile_net(True, True)
|
||||
|
||||
def test_dmnet_train_step_mp_recompute_recompute_activation_slice():
|
||||
"""
|
||||
Feature: test recompute interface.
|
||||
Description: test model parallel communication not recompute and slicing recompute cell output.
|
||||
Expectation: compile without error.
|
||||
"""
|
||||
compile_net(False, True)
|
||||
|
|
|
@ -652,9 +652,9 @@ def test_transformer_parallel_config():
|
|||
with pytest.raises(TypeError):
|
||||
parallel_test_config.recompute = 1
|
||||
|
||||
parallel_test_config.recompute = False
|
||||
parallel_test_config.recompute.recompute = False
|
||||
|
||||
assert not parallel_test_config.recompute
|
||||
assert not parallel_test_config.recompute.recompute
|
||||
|
||||
|
||||
def test_parallel_config():
|
||||
|
|
Loading…
Reference in New Issue