!27928 recompute_node_add_micro_attr

Merge pull request !27928 from yao_yf/recompute_node_add_micro_atte
This commit is contained in:
i-robot 2021-12-31 02:41:01 +00:00 committed by Gitee
commit e91f2286a1
12 changed files with 208 additions and 23 deletions

View File

@ -25,6 +25,7 @@
#include "base/core_ops.h"
#include "utils/utils.h"
#include "utils/shape_utils.h"
#include "frontend/parallel/context.h"
#include "runtime/device/kernel_info.h"
#include "runtime/device/device_address.h"
#include "backend/optimizer/common/helper.h"
@ -521,6 +522,11 @@ void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr
MS_EXCEPTION_IF_NULL(from_primitive);
auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
MS_EXCEPTION_IF_NULL(to_primitive);
auto from_cnode = from->cast<CNodePtr>();
auto to_cnode = to->cast<CNodePtr>();
if (from_cnode->HasPrimalAttr(kAttrMicro)) {
to_cnode->AddPrimalAttr(kAttrMicro, from_cnode->GetPrimalAttr(kAttrMicro));
}
(void)to_primitive->SetAttrs(from_primitive->attrs());
}
@ -1809,10 +1815,13 @@ std::vector<CNodePtr> DelayExecNode(const std::vector<CNodePtr> &nodes, const st
// Reorders the kernel execution list in place.
// The current contents of node_list are copied into a working vector, the vector is passed through
// a sequence of DelayExecNode calls (one per op name), and the final order is written back into
// node_list.
void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
std::vector<CNodePtr> result;
std::copy(node_list->begin(), node_list->end(), std::back_inserter(result));
// NOTE(review): the four string-literal calls below are immediately followed by four calls that use
// the k*OpName constants with the same boolean flags. This looks like the removed and the added side
// of a diff shown together -- confirm that only one of the two sets is meant to be present.
result = DelayExecNode(result, "TransData", true);
result = DelayExecNode(result, "Cast", true);
result = DelayExecNode(result, "AdamApplyOneWithDecay", false);
result = DelayExecNode(result, "AdamApplyOne", false);
result = DelayExecNode(result, kTransDataOpName, true);
result = DelayExecNode(result, kCastOpName, true);
result = DelayExecNode(result, kAdamApplyOneWithDecayOpName, false);
result = DelayExecNode(result, kAdamApplyOneOpName, false);
// DropoutGenMask is only delayed when the pipeline stage split number is greater than 1.
if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
result = DelayExecNode(result, kDropoutGenMaskOpName, true);
}
// Replace the caller's list with the reordered one.
node_list->clear();
std::copy(result.begin(), result.end(), std::back_inserter(*node_list));
}

View File

@ -335,6 +335,9 @@ CNodePtr CreateNewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &orig
recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
recomputed_node->set_abstract(origin_node->abstract());
recomputed_node->set_scope(origin_node->scope());
if (origin_node->HasPrimalAttr(kAttrMicro)) {
recomputed_node->AddPrimalAttr(kAttrMicro, origin_node->GetPrimalAttr(kAttrMicro));
}
return recomputed_node;
}

View File

@ -27,8 +27,9 @@
namespace mindspore {
namespace parallel {
DeviceManagerPtr g_device_manager = nullptr;
bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend,
const std::vector<int64_t> &stage) {
bool CheckDeviceConfig(int64_t device_num, int64_t global_rank, const std::string &backend,
const std::vector<int64_t> &stage) {
if (device_num <= 0) {
MS_LOG(ERROR) << "The context configuration parameter 'device_num' must be positive, "
"but got the value of device_num: "
@ -46,10 +47,10 @@ bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &back
<< ", but got the value of device_num: " << device_num;
return false;
}
// 'device_num_converted' must be the power of 2
if ((LongToUlong(device_num) & LongToUlong(device_num - 1)) != 0) {
MS_LOG(ERROR) << "The context configuration parameter device_num' must be the power of 2, "
"but got the value of device_num: "
// 'device_num' must be divisible by 8 (DEVICE_NUM_PER_SERVER), or be equal to 1, 2 or 4
if (device_num % DEVICE_NUM_PER_SERVER != 0 && device_num != 1 && device_num != 2 && device_num != 4) {
MS_LOG(ERROR) << "The context configuration parameter device_num' must be divisible by 8, "
"or equal to 1, 2 or 4, but got the value of device_num: "
<< device_num;
return false;
}
@ -69,6 +70,14 @@ bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &back
MS_LOG(ERROR) << "The size of parameter 'stage' must be positive, but got the size of stage is empty.";
return false;
}
return true;
}
bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend,
const std::vector<int64_t> &stage) {
if (!CheckDeviceConfig(device_num, global_rank, backend, stage)) {
return false;
}
RankList devices, stage_map;
for (int64_t i = 0; i < device_num; ++i) {

View File

@ -37,6 +37,7 @@ namespace mindspore {
namespace parallel {
#define MAX_DEVICE_NUM 4096
constexpr size_t DEVICE_NUM_PER_SERVER = 8;
constexpr char HCCL_BACKEND[] = "hccl";
constexpr char NCCL_BACKEND[] = "nccl";
constexpr char UNDEFINED_BACKEND[] = "undefined_backend";

View File

@ -114,6 +114,10 @@ Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shape
}
if ((LongToUlong(strategy_value) & LongToUlong(strategy_value - 1)) != 0) {
if ((g_device_manager->DeviceNum() & (g_device_manager->DeviceNum() - 1)) != 0) {
MS_LOG(WARNING) << "The device num is not the power of 2, thus do not check the strategy as power of 2";
return SUCCESS;
}
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra)
<< ", the value of strategy must be the power of 2, but get " << strategy_value;

View File

@ -3152,11 +3152,22 @@ static void HandleDataParallel() {
}
}
// Graph preparation that is only performed for pipeline parallelism.
// Does nothing when ParallelContext reports a pipeline stage split number of 1 or less; otherwise
// calls HandleMicroBatch, ParameterStartNode and LastStageEndNode, in that order.
static void PipelinePreProcess(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager,
                               const std::vector<AnfNodePtr> &all_nodes) {
  const bool pipeline_enabled = ParallelContext::GetInstance()->pipeline_stage_split_num() > 1;
  if (!pipeline_enabled) {
    return;
  }
  HandleMicroBatch(all_nodes, manager);
  ParameterStartNode(all_nodes, manager);
  LastStageEndNode(all_nodes, manager, root);
}
// Graph post-processing that is only performed for pipeline parallelism.
// Does nothing when ParallelContext reports a pipeline stage split number of 1 or less; otherwise
// calls AddVirtualAssignAdd, HandleReceiveParam and LabelGenMaskMicro, in that order.
static void PipelinePostProcess(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  const bool pipeline_enabled = ParallelContext::GetInstance()->pipeline_stage_split_num() > 1;
  if (!pipeline_enabled) {
    return;
  }
  AddVirtualAssignAdd(root);
  HandleReceiveParam(root, all_nodes);
  LabelGenMaskMicro(root);
}
@ -3207,12 +3218,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
if (pipeline_stages <= 1 && ParallelInit() != SUCCESS) {
MS_LOG(EXCEPTION) << "Parallel init failed";
}
if (pipeline_stages > 1) {
HandleMicroBatch(all_nodes, manager);
ParameterStartNode(all_nodes, manager);
LastStageEndNode(all_nodes, manager, root);
}
PipelinePreProcess(root, manager, all_nodes);
// mark the forward cnodes, parallel only care these nodes
MarkForwardCNode(root);
@ -3234,6 +3240,8 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
ReshapeInit(all_nodes);
}
SetCastForParamNotRecompute(all_nodes);
HandleRootReshapeAndSaveStrategy(all_nodes);
HandleForwardMakeTupleAndMakeList(all_nodes);

View File

@ -422,6 +422,62 @@ AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, cons
auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), pre_node, type_node});
new_node->set_abstract(node->abstract());
return new_node;
} // namespace parallel
}
// Returns the node that actually feeds input `index` of `cnode`, looking through
// Load / Depend / UpdateState wrappers.
//
// Starting from cnode->input(index), the walk continues while the current node is a CNode whose
// primitive is named LOAD, DEPEND or UPDATESTATE: Load and Depend are followed through their
// input 1, UpdateState through its input 2. The first node that is not a CNode, or whose primitive
// is none of those three, is returned.
//
// Raises an exception when `cnode` is null, when `index` is not smaller than cnode->size(), or when
// a CNode met during the walk has no primitive (GetCNodePrimitive returns null).
AnfNodePtr RealInputNode(const CNodePtr cnode, size_t index) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (index >= cnode->size()) {
    MS_LOG(EXCEPTION) << "cnode inputs size: " << cnode->size() << " is less equal index: " << index;
  }
  auto current = cnode->input(index);
  while (current->isa<CNode>()) {
    auto prim = GetCNodePrimitive(current);
    MS_EXCEPTION_IF_NULL(prim);
    const auto prim_name = prim->name();
    // Which input of the wrapper carries the real value.
    size_t forwarded_input = 0;
    if (prim_name == LOAD || prim_name == DEPEND) {
      forwarded_input = 1;
    } else if (prim_name == UPDATESTATE) {
      forwarded_input = 2;
    } else {
      // Not a wrapper: this is the real producer.
      break;
    }
    current = current->cast<CNodePtr>()->input(forwarded_input);
  }
  return current;
}
void LabelGenMaskMicro(const FuncGraphPtr &root) {
AnfNodePtr ret = root->get_return();
MS_EXCEPTION_IF_NULL(ret);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
for (auto &node : all_nodes) {
if (IsPrimitiveCNode(node, prim::kPrimDropoutDoMask)) {
auto gen_mask_node = RealInputNode(node->cast<CNodePtr>(), 2);
if (gen_mask_node->isa<CNode>()) {
gen_mask_node->cast<CNodePtr>()->set_primal_attrs(node->cast<CNodePtr>()->primal_attrs());
}
}
}
}
void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes) {
for (const auto &node : all_nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto cast_input = RealInputNode(cnode, 1);
if (cast_input->isa<Parameter>()) {
MS_LOG(INFO) << "Cast for parameter no needs recompute to avoid redundant trans_data operator";
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
prim->AddAttr("recompute", MakeValue(false));
}
}
}
} // namespace parallel
} // namespace mindspore

View File

@ -27,19 +27,25 @@ namespace mindspore {
namespace parallel {
const int64_t TWO_INPUT_SIZE = 2;
// common method
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
bool IsParallelCareNode(const CNodePtr &cnode);
AnfNodePtr RealInputNode(const CNodePtr cnode, size_t index);
Shapes GetNodeShape(const AnfNodePtr &node);
RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
const CNodePtr &node);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
// for specific scenarios
RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
const TypePtr &compute_node_type);
AnfNodePtr GetChildCastNode(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
void LabelGenMaskMicro(const FuncGraphPtr &root);
void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes);
} // namespace parallel
} // namespace mindspore

View File

@ -271,6 +271,7 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
InsertStreamActive(graph_ptr);
InsertEventForHcomParallel(graph_ptr);
InsertEventForIndependentParallel(graph_ptr);
InsertEventForMicroBatchIndependent(graph_ptr);
GetIndependentMaxTarget(graph_ptr);
InsertCtrlForIndependentParallel(graph_ptr);
AdjustAtomicAddrCleanOrder(graph_ptr);
@ -2597,6 +2598,92 @@ void AscendStreamAssign::AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr
}
graph_ptr->set_execution_order(update_orders);
}
// Searches the graph's execution order, starting at the position of `do_mask_cnode`, for the first
// kernel named DropoutGenMask / DropoutGenMaskV3 that carries the primal attr kAttrMicro.
// Returns that kernel, or nullptr when there is none (including the case where `do_mask_cnode`
// itself is not part of the execution order).
CNodePtr FindNextGenMask(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr do_mask_cnode) {
  const auto &order = graph_ptr->execution_order();
  const auto search_begin = std::find(order.begin(), order.end(), do_mask_cnode);
  const auto found = std::find_if(search_begin, order.end(), [](const CNodePtr &candidate) {
    const auto op_name = AnfAlgo::GetCNodeName(candidate);
    const bool is_gen_mask = (op_name == kDropoutGenMaskOpName || op_name == kDropoutGenMaskV3OpName);
    return is_gen_mask && candidate->HasPrimalAttr(kAttrMicro);
  });
  if (found == order.end()) {
    return nullptr;
  }
  return *found;
}
void AscendStreamAssign::InsertEventForMicroBatchIndependent(const NotNull<KernelGraphPtr> &graph_ptr) {
if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() <= 1) {
return;
}
std::map<CNodePtr, CNodePtr> node_send_map;
std::map<CNodePtr, CNodePtr> node_recv_map;
std::map<size_t, CNodePtr> micro_last_cnode_map;
AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();
auto &exec_order = graph_ptr->execution_order();
for (auto &cnode : exec_order) {
if (AnfAlgo::GetCNodeName(cnode) != kDropoutDoMaskOpName &&
AnfAlgo::GetCNodeName(cnode) != kDropoutDoMaskV3OpName) {
continue;
}
if (!cnode->HasPrimalAttr(kAttrMicro)) {
MS_LOG(WARNING) << "Node doesn't have the attr [micro], node: " << cnode->fullname_with_scope();
continue;
}
auto micro_ptr = cnode->GetPrimalAttr(kAttrMicro);
auto micro_value = GetValue<int64_t>(micro_ptr);
micro_last_cnode_map[micro_value] = cnode;
}
for (auto &micro_cnode_item : micro_last_cnode_map) {
auto cnode = micro_cnode_item.second;
auto micro_batch = micro_cnode_item.first;
MS_LOG(INFO) << "Micro: " << micro_batch << ", last DropoutDoMask: " << cnode->fullname_with_scope();
auto next_gen_mask = FindNextGenMask(graph_ptr, cnode);
if (next_gen_mask == nullptr) {
MS_LOG(INFO) << "Node doesn't have the next DropoutGenMask, node: " << cnode->fullname_with_scope()
<< ", micro value: " << micro_batch;
continue;
}
MS_LOG(INFO) << "Insert send after node: " << cnode->fullname_with_scope()
<< ", insert recv before node: " << next_gen_mask->fullname_with_scope();
uint32_t cur_event_id = resource_manager.ApplyNewEvent();
CNodePtr send_cnode = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId((cnode)));
CNodePtr recv_cnode = CreateRecvApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(next_gen_mask));
node_send_map[cnode] = send_cnode;
node_recv_map[next_gen_mask] = recv_cnode;
}
MS_LOG(INFO) << "Print execution order before inserting event between DropoutDoMask and DropoutGenMask.";
graph_ptr->PrintGraphExecuteOrder();
std::vector<CNodePtr> new_exec_order;
for (auto &cnode : exec_order) {
auto cnode_name = AnfAlgo::GetCNodeName(cnode);
if (cnode_name == kDropoutDoMaskOpName || cnode_name == kDropoutDoMaskV3OpName) {
auto send_iter = node_send_map.find(cnode);
if (send_iter != node_send_map.end()) {
new_exec_order.push_back(cnode);
new_exec_order.push_back((*send_iter).second);
continue;
}
}
if (cnode_name == kDropoutGenMaskOpName || cnode_name == kDropoutGenMaskV3OpName) {
auto recv_iter = node_recv_map.find(cnode);
if (recv_iter != node_recv_map.end()) {
new_exec_order.push_back((*recv_iter).second);
new_exec_order.push_back(cnode);
continue;
}
}
new_exec_order.push_back(cnode);
}
graph_ptr->set_execution_order(new_exec_order);
MS_LOG(INFO) << "Print execution order after inserting event between DropoutDoMask and DropoutGenMask.";
graph_ptr->PrintGraphExecuteOrder();
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -156,6 +156,8 @@ class AscendStreamAssign {
void GetAllGraphID(const NotNull<KernelGraphPtr> &graph_ptr, std::vector<uint32_t> *graphs_id);
void GraphLoopSync(const NotNull<KernelGraphPtr> &root_graph, uint32_t graph_id);
void InsertEventForMicroBatchIndependent(const NotNull<KernelGraphPtr> &graph_ptr);
bool independent_stream_activated_{false};
bool hcom_stream_activated_{false};
bool loop_sink_{false};

View File

@ -514,6 +514,7 @@ constexpr auto kAttrFuncType = "func_type";
constexpr auto kAttrCustAicpu = "cust_aicpu";
constexpr auto kAttrIsInternalOutputNopNode = "is_internal_output_nop_node";
constexpr auto kAttrIsUBFusionOp = "is_ub_fusion_op";
constexpr auto kAttrMicro = "micro";
// custom operator func type
constexpr auto kCustomTypeAOT = "aot";

View File

@ -190,8 +190,7 @@ class TransformerRecomputeConfig(_Config):
self._recompute_slice_activation = value
_DEFALUT_TRANSFORMER_RECOMPUTE_CONFIG = TransformerRecomputeConfig()
default_transformer_recompute_config = TransformerRecomputeConfig()
class TransformerOpParallelConfig(_Config):
r"""
@ -231,7 +230,7 @@ class TransformerOpParallelConfig(_Config):
"""
def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1,
recompute=_DEFALUT_TRANSFORMER_RECOMPUTE_CONFIG,
recompute=default_transformer_recompute_config,
optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
self.recompute = recompute
self.optimizer_shard = optimizer_shard