forked from mindspore-Ecosystem/mindspore
!27928 recompute_node_add_micro_attr
Merge pull request !27928 from yao_yf/recompute_node_add_micro_atte
Commit e91f2286a1
@@ -25,6 +25,7 @@
#include "base/core_ops.h"
#include "utils/utils.h"
#include "utils/shape_utils.h"
#include "frontend/parallel/context.h"
#include "runtime/device/kernel_info.h"
#include "runtime/device/device_address.h"
#include "backend/optimizer/common/helper.h"

@@ -521,6 +522,11 @@ void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr
  MS_EXCEPTION_IF_NULL(from_primitive);
  auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  MS_EXCEPTION_IF_NULL(to_primitive);
  auto from_cnode = from->cast<CNodePtr>();
  auto to_cnode = to->cast<CNodePtr>();
  if (from_cnode->HasPrimalAttr(kAttrMicro)) {
    to_cnode->AddPrimalAttr(kAttrMicro, from_cnode->GetPrimalAttr(kAttrMicro));
  }
  (void)to_primitive->SetAttrs(from_primitive->attrs());
}

@@ -1809,10 +1815,13 @@ std::vector<CNodePtr> DelayExecNode(const std::vector<CNodePtr> &nodes, const st
void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
  std::vector<CNodePtr> result;
  std::copy(node_list->begin(), node_list->end(), std::back_inserter(result));
  result = DelayExecNode(result, "TransData", true);
  result = DelayExecNode(result, "Cast", true);
  result = DelayExecNode(result, "AdamApplyOneWithDecay", false);
  result = DelayExecNode(result, "AdamApplyOne", false);
  result = DelayExecNode(result, kTransDataOpName, true);
  result = DelayExecNode(result, kCastOpName, true);
  result = DelayExecNode(result, kAdamApplyOneWithDecayOpName, false);
  result = DelayExecNode(result, kAdamApplyOneOpName, false);
  if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
    result = DelayExecNode(result, kDropoutGenMaskOpName, true);
  }
  node_list->clear();
  std::copy(result.begin(), result.end(), std::back_inserter(*node_list));
}

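DelayExecNode itself is not part of this diff, so its exact semantics are not visible here. As a rough illustration of what "delaying" every node of a given type in an execution order can look like, the following deliberately simplified stand-in pushes matching entries behind all other entries while preserving relative order; the real pass presumably also honours data dependencies and the boolean flag passed above, which this sketch ignores. Illustrative only, not MindSpore code:

#include <algorithm>
#include <string>
#include <vector>

// Simplified stand-in: move every entry equal to op_name behind all other
// entries, keeping the relative order of both groups (stable partition).
std::vector<std::string> DelayByName(std::vector<std::string> exec_order, const std::string &op_name) {
  std::stable_partition(exec_order.begin(), exec_order.end(),
                        [&op_name](const std::string &name) { return name != op_name; });
  return exec_order;
}

// Example: {"Cast", "MatMul", "Cast", "Add"} -> {"MatMul", "Add", "Cast", "Cast"}.
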
@@ -335,6 +335,9 @@ CNodePtr CreateNewRecomputedNode(const FuncGraphPtr &graph, const CNodePtr &orig
  recomputed_node->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
  recomputed_node->set_abstract(origin_node->abstract());
  recomputed_node->set_scope(origin_node->scope());
  if (origin_node->HasPrimalAttr(kAttrMicro)) {
    recomputed_node->AddPrimalAttr(kAttrMicro, origin_node->GetPrimalAttr(kAttrMicro));
  }
  return recomputed_node;
}

@@ -27,8 +27,9 @@
namespace mindspore {
namespace parallel {
DeviceManagerPtr g_device_manager = nullptr;
bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend,
                const std::vector<int64_t> &stage) {

bool CheckDeviceConfig(int64_t device_num, int64_t global_rank, const std::string &backend,
                       const std::vector<int64_t> &stage) {
  if (device_num <= 0) {
    MS_LOG(ERROR) << "The context configuration parameter 'device_num' must be positive, "
                     "but got the value of device_num: "
@@ -46,10 +47,10 @@ bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &back
                  << ", but got the value of device_num: " << device_num;
    return false;
  }
  // 'device_num_converted' must be the power of 2
  if ((LongToUlong(device_num) & LongToUlong(device_num - 1)) != 0) {
    MS_LOG(ERROR) << "The context configuration parameter 'device_num' must be the power of 2, "
                     "but got the value of device_num: "
  // 'device_num_converted' must be divisible by 8
  if (device_num % DEVICE_NUM_PER_SERVER != 0 && device_num != 1 && device_num != 2 && device_num != 4) {
    MS_LOG(ERROR) << "The context configuration parameter 'device_num' must be divisible by 8, "
                     "or equal to 1, 2 or 4, but got the value of device_num: "
                  << device_num;
    return false;
  }

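The divisible-by-8-or-{1, 2, 4} rule that appears above can be read in isolation: the device count must be positive and either a multiple of the per-server device count or one of the small single-server sizes. A compact standalone sketch of just that rule (illustrative only, not MindSpore code; the constant 8 mirrors DEVICE_NUM_PER_SERVER from the header below):

#include <cstdint>
#include <iostream>

// A device count is accepted when it is positive and either divisible by the
// per-server device count (8) or one of the small single-server sizes 1, 2, 4.
bool IsAcceptedDeviceNum(int64_t device_num) {
  constexpr int64_t kDeviceNumPerServer = 8;
  if (device_num <= 0) {
    return false;
  }
  return device_num % kDeviceNumPerServer == 0 || device_num == 1 || device_num == 2 || device_num == 4;
}

int main() {
  for (int64_t n : {1, 2, 4, 6, 8, 16, 24, 4096}) {
    std::cout << n << (IsAcceptedDeviceNum(n) ? " accepted" : " rejected") << "\n";
  }
  return 0;
}
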
@@ -69,6 +70,14 @@ bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &back
    MS_LOG(ERROR) << "The size of parameter 'stage' must be positive, but got the size of stage is empty.";
    return false;
  }
  return true;
}

bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend,
                const std::vector<int64_t> &stage) {
  if (!CheckDeviceConfig(device_num, global_rank, backend, stage)) {
    return false;
  }

  RankList devices, stage_map;
  for (int64_t i = 0; i < device_num; ++i) {

@@ -37,6 +37,7 @@ namespace mindspore {
namespace parallel {
#define MAX_DEVICE_NUM 4096

constexpr size_t DEVICE_NUM_PER_SERVER = 8;
constexpr char HCCL_BACKEND[] = "hccl";
constexpr char NCCL_BACKEND[] = "nccl";
constexpr char UNDEFINED_BACKEND[] = "undefined_backend";

@@ -114,6 +114,10 @@ Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shape
      }

      if ((LongToUlong(strategy_value) & LongToUlong(strategy_value - 1)) != 0) {
        if ((g_device_manager->DeviceNum() & (g_device_manager->DeviceNum() - 1)) != 0) {
          MS_LOG(WARNING) << "The device num is not a power of 2, so the strategy is not checked for being a power of 2";
          return SUCCESS;
        }
        if (is_auto_parallel_) {
          MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra)
                        << ", the value of strategy must be the power of 2, but get " << strategy_value;

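The checks above and in the device_num validation rely on the standard bit trick: a positive integer n is a power of two exactly when (n & (n - 1)) == 0, because a power of two has a single set bit and subtracting one clears that bit while setting all lower bits. A minimal standalone sketch of that test (illustrative only, not MindSpore code):

#include <cassert>
#include <cstdint>

// True when value is a positive power of two: such a value has exactly one
// set bit, so clearing the lowest set bit (value & (value - 1)) yields zero.
bool IsPowerOfTwo(int64_t value) {
  if (value <= 0) {
    return false;
  }
  uint64_t v = static_cast<uint64_t>(value);
  return (v & (v - 1)) == 0;
}

int main() {
  assert(IsPowerOfTwo(1) && IsPowerOfTwo(2) && IsPowerOfTwo(1024));
  assert(!IsPowerOfTwo(0) && !IsPowerOfTwo(3) && !IsPowerOfTwo(12));
  return 0;
}
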
@@ -3152,11 +3152,22 @@ static void HandleDataParallel() {
  }
}

static void PipelinePreProcess(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager,
                               const std::vector<AnfNodePtr> &all_nodes) {
  auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
  if (pipeline_stages > 1) {
    HandleMicroBatch(all_nodes, manager);
    ParameterStartNode(all_nodes, manager);
    LastStageEndNode(all_nodes, manager, root);
  }
}

static void PipelinePostProcess(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
  if (pipeline_stages > 1) {
    AddVirtualAssignAdd(root);
    HandleReceiveParam(root, all_nodes);
    LabelGenMaskMicro(root);
  }
}

@@ -3207,12 +3218,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
  if (pipeline_stages <= 1 && ParallelInit() != SUCCESS) {
    MS_LOG(EXCEPTION) << "Parallel init failed";
  }

  if (pipeline_stages > 1) {
    HandleMicroBatch(all_nodes, manager);
    ParameterStartNode(all_nodes, manager);
    LastStageEndNode(all_nodes, manager, root);
  }
  PipelinePreProcess(root, manager, all_nodes);

  // mark the forward cnodes, parallel only care these nodes
  MarkForwardCNode(root);

@@ -3234,6 +3240,8 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
    ReshapeInit(all_nodes);
  }

  SetCastForParamNotRecompute(all_nodes);

  HandleRootReshapeAndSaveStrategy(all_nodes);

  HandleForwardMakeTupleAndMakeList(all_nodes);

@@ -422,6 +422,62 @@ AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, cons
  auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), pre_node, type_node});
  new_node->set_abstract(node->abstract());
  return new_node;
}  // namespace parallel
}

AnfNodePtr RealInputNode(const CNodePtr cnode, size_t index) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->size() <= index) {
    MS_LOG(EXCEPTION) << "cnode inputs size: " << cnode->size() << " is less than or equal to the index: " << index;
  }
  auto input0 = cnode->input(index);
  if (!input0->isa<CNode>()) {
    return input0;
  }
  auto prim = GetCNodePrimitive(input0);
  MS_EXCEPTION_IF_NULL(prim);
  while (prim->name() == LOAD || prim->name() == DEPEND || prim->name() == UPDATESTATE) {
    if (prim->name() == LOAD || prim->name() == DEPEND) {
      input0 = input0->cast<CNodePtr>()->input(1);
    } else if (prim->name() == UPDATESTATE) {
      input0 = input0->cast<CNodePtr>()->input(2);
    }
    if (!input0->isa<CNode>()) {
      return input0;
    }
    prim = GetCNodePrimitive(input0);
    MS_EXCEPTION_IF_NULL(prim);
  }
  return input0;
}

void LabelGenMaskMicro(const FuncGraphPtr &root) {
  AnfNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  for (auto &node : all_nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimDropoutDoMask)) {
      auto gen_mask_node = RealInputNode(node->cast<CNodePtr>(), 2);
      if (gen_mask_node->isa<CNode>()) {
        gen_mask_node->cast<CNodePtr>()->set_primal_attrs(node->cast<CNodePtr>()->primal_attrs());
      }
    }
  }
}

void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes) {
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto cast_input = RealInputNode(cnode, 1);
    if (cast_input->isa<Parameter>()) {
      MS_LOG(INFO) << "Cast for parameter does not need to be recomputed, to avoid a redundant trans_data operator";
      PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
      prim->AddAttr("recompute", MakeValue(false));
    }
  }
}

}  // namespace parallel
}  // namespace mindspore

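RealInputNode above walks through Load/Depend/UpdateState wrapper nodes: Load and Depend keep their real data operand at input index 1, while UpdateState keeps it at index 2, so following those indices eventually reaches the node that actually produces the value. A toy sketch of the same idea on a simplified node structure (illustrative only, not MindSpore code):

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Toy node: a name plus inputs, where inputs[0] stands for the primitive slot
// as in a CNode, so real operands start at index 1.
struct ToyNode {
  std::string name;
  std::vector<std::shared_ptr<ToyNode>> inputs;
};

// Follow wrapper nodes to the node that actually produces the value:
// Load/Depend forward it at input index 1, UpdateState at input index 2.
std::shared_ptr<ToyNode> RealProducer(std::shared_ptr<ToyNode> node) {
  while (node != nullptr &&
         (node->name == "Load" || node->name == "Depend" || node->name == "UpdateState")) {
    std::size_t real_index = (node->name == "UpdateState") ? 2 : 1;
    if (node->inputs.size() <= real_index) {
      return node;
    }
    node = node->inputs[real_index];
  }
  return node;
}

int main() {
  auto param = std::make_shared<ToyNode>(ToyNode{"Parameter", {}});
  auto load = std::make_shared<ToyNode>(ToyNode{"Load", {nullptr, param}});
  auto depend = std::make_shared<ToyNode>(ToyNode{"Depend", {nullptr, load}});
  // Resolves Depend -> Load -> Parameter.
  return RealProducer(depend)->name == "Parameter" ? 0 : 1;
}
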
@@ -27,19 +27,25 @@ namespace mindspore {
namespace parallel {
const int64_t TWO_INPUT_SIZE = 2;

// common method
bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
bool IsParallelCareNode(const CNodePtr &cnode);
AnfNodePtr RealInputNode(const CNodePtr cnode, size_t index);
Shapes GetNodeShape(const AnfNodePtr &node);
RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
std::string CreateInstanceName(const CNodePtr &node, size_t index);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
                                       const CNodePtr &node);
std::string CreateInstanceName(const CNodePtr &node, size_t index);

// for specific scenarios
RankList FindCommonMirrorGroup(const FuncGraphPtr &root);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
void SetStridedSliceSplitStrategy(const std::vector<AnfNodePtr> &all_nodes);
AnfNodePtr CreateFP16Cast(const CNodePtr &node, const AnfNodePtr &pre_node, const NodeUsersMap &node_user_map,
                          const TypePtr &compute_node_type);
AnfNodePtr GetChildCastNode(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
TypePtr FindChildCastWithFP32ToFP16(const CNodePtr &cnode_ptr, const NodeUsersMap &node_users_map);
void LabelGenMaskMicro(const FuncGraphPtr &root);
void SetCastForParamNotRecompute(const std::vector<AnfNodePtr> &all_nodes);
}  // namespace parallel
}  // namespace mindspore

@@ -271,6 +271,7 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
  InsertStreamActive(graph_ptr);
  InsertEventForHcomParallel(graph_ptr);
  InsertEventForIndependentParallel(graph_ptr);
  InsertEventForMicroBatchIndependent(graph_ptr);
  GetIndependentMaxTarget(graph_ptr);
  InsertCtrlForIndependentParallel(graph_ptr);
  AdjustAtomicAddrCleanOrder(graph_ptr);

@@ -2597,6 +2598,92 @@ void AscendStreamAssign::AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr
  }
  graph_ptr->set_execution_order(update_orders);
}

CNodePtr FindNextGenMask(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr do_mask_cnode) {
  auto &exec_order = graph_ptr->execution_order();
  auto iter = std::find(exec_order.begin(), exec_order.end(), do_mask_cnode);
  for (; iter != exec_order.end(); iter++) {
    auto cnode = *iter;
    if ((AnfAlgo::GetCNodeName(cnode) != kDropoutGenMaskOpName &&
         AnfAlgo::GetCNodeName(cnode) != kDropoutGenMaskV3OpName) ||
        !cnode->HasPrimalAttr(kAttrMicro)) {
      continue;
    }
    return cnode;
  }
  return nullptr;
}

void AscendStreamAssign::InsertEventForMicroBatchIndependent(const NotNull<KernelGraphPtr> &graph_ptr) {
  if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() <= 1) {
    return;
  }
  std::map<CNodePtr, CNodePtr> node_send_map;
  std::map<CNodePtr, CNodePtr> node_recv_map;
  std::map<size_t, CNodePtr> micro_last_cnode_map;
  AscendStreamMng &resource_manager = AscendStreamMng::GetInstance();

  auto &exec_order = graph_ptr->execution_order();
  for (auto &cnode : exec_order) {
    if (AnfAlgo::GetCNodeName(cnode) != kDropoutDoMaskOpName &&
        AnfAlgo::GetCNodeName(cnode) != kDropoutDoMaskV3OpName) {
      continue;
    }
    if (!cnode->HasPrimalAttr(kAttrMicro)) {
      MS_LOG(WARNING) << "Node doesn't have the attr [micro], node: " << cnode->fullname_with_scope();
      continue;
    }
    auto micro_ptr = cnode->GetPrimalAttr(kAttrMicro);
    auto micro_value = GetValue<int64_t>(micro_ptr);
    micro_last_cnode_map[micro_value] = cnode;
  }

  for (auto &micro_cnode_item : micro_last_cnode_map) {
    auto cnode = micro_cnode_item.second;
    auto micro_batch = micro_cnode_item.first;
    MS_LOG(INFO) << "Micro: " << micro_batch << ", last DropoutDoMask: " << cnode->fullname_with_scope();
    auto next_gen_mask = FindNextGenMask(graph_ptr, cnode);
    if (next_gen_mask == nullptr) {
      MS_LOG(INFO) << "Node doesn't have the next DropoutGenMask, node: " << cnode->fullname_with_scope()
                   << ", micro value: " << micro_batch;
      continue;
    }
    MS_LOG(INFO) << "Insert send after node: " << cnode->fullname_with_scope()
                 << ", insert recv before node: " << next_gen_mask->fullname_with_scope();
    uint32_t cur_event_id = resource_manager.ApplyNewEvent();
    CNodePtr send_cnode = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(cnode));
    CNodePtr recv_cnode = CreateRecvApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(next_gen_mask));
    node_send_map[cnode] = send_cnode;
    node_recv_map[next_gen_mask] = recv_cnode;
  }

  MS_LOG(INFO) << "Print execution order before inserting event between DropoutDoMask and DropoutGenMask.";
  graph_ptr->PrintGraphExecuteOrder();
  std::vector<CNodePtr> new_exec_order;
  for (auto &cnode : exec_order) {
    auto cnode_name = AnfAlgo::GetCNodeName(cnode);
    if (cnode_name == kDropoutDoMaskOpName || cnode_name == kDropoutDoMaskV3OpName) {
      auto send_iter = node_send_map.find(cnode);
      if (send_iter != node_send_map.end()) {
        new_exec_order.push_back(cnode);
        new_exec_order.push_back((*send_iter).second);
        continue;
      }
    }
    if (cnode_name == kDropoutGenMaskOpName || cnode_name == kDropoutGenMaskV3OpName) {
      auto recv_iter = node_recv_map.find(cnode);
      if (recv_iter != node_recv_map.end()) {
        new_exec_order.push_back((*recv_iter).second);
        new_exec_order.push_back(cnode);
        continue;
      }
    }
    new_exec_order.push_back(cnode);
  }
  graph_ptr->set_execution_order(new_exec_order);
  MS_LOG(INFO) << "Print execution order after inserting event between DropoutDoMask and DropoutGenMask.";
  graph_ptr->PrintGraphExecuteOrder();
}
}  // namespace ascend
}  // namespace device
}  // namespace mindspore

@@ -156,6 +156,8 @@ class AscendStreamAssign {
  void GetAllGraphID(const NotNull<KernelGraphPtr> &graph_ptr, std::vector<uint32_t> *graphs_id);
  void GraphLoopSync(const NotNull<KernelGraphPtr> &root_graph, uint32_t graph_id);

  void InsertEventForMicroBatchIndependent(const NotNull<KernelGraphPtr> &graph_ptr);

  bool independent_stream_activated_{false};
  bool hcom_stream_activated_{false};
  bool loop_sink_{false};

@@ -514,6 +514,7 @@ constexpr auto kAttrFuncType = "func_type";
constexpr auto kAttrCustAicpu = "cust_aicpu";
constexpr auto kAttrIsInternalOutputNopNode = "is_internal_output_nop_node";
constexpr auto kAttrIsUBFusionOp = "is_ub_fusion_op";
constexpr auto kAttrMicro = "micro";

// custom operator func type
constexpr auto kCustomTypeAOT = "aot";

@@ -190,8 +190,7 @@ class TransformerRecomputeConfig(_Config):
        self._recompute_slice_activation = value


_DEFALUT_TRANSFORMER_RECOMPUTE_CONFIG = TransformerRecomputeConfig()

default_transformer_recompute_config = TransformerRecomputeConfig()


class TransformerOpParallelConfig(_Config):
    r"""

@@ -231,7 +230,7 @@ class TransformerOpParallelConfig(_Config):
    """

    def __init__(self, data_parallel=1, model_parallel=1, pipeline_stage=1, micro_batch_num=1,
                 recompute=_DEFALUT_TRANSFORMER_RECOMPUTE_CONFIG,
                 recompute=default_transformer_recompute_config,
                 optimizer_shard=False, gradient_aggregation_group=4, vocab_emb_dp=True):
        self.recompute = recompute
        self.optimizer_shard = optimizer_shard