add pipeline shard interface

Add support for no pipeline accugradient

Add delay tag for fusion op

Optimizer the visite order

add mirror for mini step control

Move the group to attributes

Add gradient_shard control for the mini step

Fix code stype

Fix ut description

Add interface
This commit is contained in:
huangxinjing 2021-08-22 17:24:21 +08:00
parent a3441bbfb5
commit f354ab22a3
21 changed files with 310 additions and 38 deletions

View File

@ -389,6 +389,9 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
auto final_node_prim = GetCNodePrimitive(final_node);
fused_prim->set_instance_name(final_node_prim->instance_name());
}
if (AnfAlgo::HasNodeAttr(kAttrNotDelayFusion, final_node)) {
AnfAlgo::CopyNodeAttr(kAttrNotDelayFusion, final_node, fused_node);
}
return fused_node;
}

View File

@ -1721,13 +1721,18 @@ bool AnfRuntimeAlgorithm::IsFusedCommunicationOp(const AnfNodePtr &node) {
auto primitive = AnfAlgo::GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(primitive);
ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
ValuePtr attr_not_delay_fusion = primitive->GetAttr(kAttrNotDelayFusion);
if (attr_fusion == nullptr) {
return false;
}
auto fusion = GetValue<int64_t>(attr_fusion);
if (fusion == 0) {
return false;
}
if (attr_not_delay_fusion && GetValue<bool>(attr_not_delay_fusion)) {
return false;
}
return true;
}

View File

@ -74,6 +74,7 @@ void ParallelContext::Reset() {
optimizer_weight_shard_aggregated_save_ = false;
sharding_propagation_ = false;
enable_all2all_ = false;
grad_accumulation_shard_ = true;
dataset_strategy_.clear();
}

View File

@ -127,6 +127,10 @@ class ParallelContext {
void set_hccl_test_available(bool hccl_test_available) { hccl_test_available_ = hccl_test_available; }
bool hccl_test_available() const { return hccl_test_available_; }
void set_grad_accumulation_shard(const bool grad_accumulation_shard) {
grad_accumulation_shard_ = grad_accumulation_shard;
}
bool grad_accumulation_shard() const { return grad_accumulation_shard_; }
bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
std::string communi_parallel_mode() const { return communi_parallel_mode_; }
@ -174,6 +178,7 @@ class ParallelContext {
std::string communi_parallel_mode_;
int64_t optimizer_weight_shard_size_;
bool optimizer_weight_shard_aggregated_save_;
bool grad_accumulation_shard_;
// In AUTO_PARALLEL mode, 'sharding_propagation_' = True indicates that sharding-configured operators
// will propagate the sharding strategies to other operators with minimum redistribution cost.
bool sharding_propagation_;

View File

@ -14,10 +14,10 @@
* limitations under the License.
*/
#include <iterator>
#include <memory>
#include <list>
#include <set>
#include <queue>
#include <algorithm>
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/graph_util/generate_graph.h"
@ -108,14 +108,37 @@ void SetStridedSliceStrategy(const AnfNodePtr &node) {
cnode->AddPrimalAttr(IN_STRATEGY, strategy);
}
CNodePtr FindNodeWithMircoSize(const AnfNodePtr &node_user, const FuncGraphManagerPtr &manager,
const NodeUsersMap &node_users_map) {
// Recursively find micro tags, this may takes much more time if layers are too much
std::queue<AnfNodePtr> visited;
visited.push(node_user);
while (!visited.empty()) {
auto cur_node = visited.front();
visited.pop();
auto users = node_users_map.at(cur_node);
for (auto &temp_user : users) {
auto cnode = temp_user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!cnode->HasPrimalAttr(MICRO)) {
visited.push(temp_user.first);
} else {
return cnode;
}
}
}
return nullptr;
}
void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
const AnfNodePtr &accu_parameter) {
const AnfNodePtr &accu_parameter, const NodeUsersMap &node_user_map) {
auto cnode = node_user.first->cast<CNodePtr>();
if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) {
return;
}
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
bool grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && enable_parallel_optimizer) {
return;
}
@ -124,9 +147,36 @@ void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const F
MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd.";
return;
}
OperatorAttrs attrs;
auto param_ptr = accu_parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
// If grad_accumulation_shard is ture, a ReduceScatter will be inserted at each micro step,
// So the fusion id should be different for each micro step
// otherwise they will be fused into the one ReduceScatter alone micro_steps.
// if grad_accumulation_shard is false, we pass an empty group, so no ReduceScatter will be inserted
ValuePtr args1 = nullptr;
ValuePtr args2 = nullptr;
ValuePtr micro = nullptr;
int64_t step = 0;
if (grad_accumulation_shard) {
auto cnode_with_micro_size = FindNodeWithMircoSize(cnode, manager, node_user_map);
if (cnode_with_micro_size && cnode_with_micro_size->HasPrimalAttr(MICRO)) {
micro = cnode_with_micro_size->GetPrimalAttr(MICRO);
step = GetValue<int64_t>(micro);
}
}
args1 = MakeValue(param_ptr->user_data<TensorLayout>()->opt_shard_group());
args2 = MakeValue(param_ptr->param_info()->comm_fusion() + step * PIPELINE_FUSTION_OFFSET);
OperatorAttrs attrs = {};
auto py_instance = CreatOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
auto value_node = NewValueNode(py_instance);
// Set the attribute of the reduce scatter
auto new_prim = GetValueNode<PrimitivePtr>(value_node);
MS_EXCEPTION_IF_NULL(new_prim);
auto attrs_prim = new_prim->attrs();
attrs_prim[GROUP] = args1;
attrs_prim[kAttrFusion] = args2;
new_prim->SetAttrs(attrs_prim);
std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
auto graph = cnode->func_graph();
auto virtual_node = graph->NewCNode(virtual_node_input);
@ -189,16 +239,47 @@ void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
auto node_set = node_users_map[temp_node];
for (auto &node_user : node_set) {
InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter, node_users_map);
}
} else {
InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter, node_users_map);
}
}
InsertVirtualAccuGrad(node, root->manager(), accu_parameter);
}
}
// If the graph likes the followings:
// 1. MicroStepAllGather->MirrorMicro->load, we need to visit the param after the load
std::vector<std::pair<AnfNodePtr, int>> FindNextNode(const std::pair<AnfNodePtr, int> &node_ptr,
const FuncGraphPtr &root, const NodeUsersMap &node_users_map) {
std::vector<std::pair<AnfNodePtr, int>> to_be_visited_set;
if (!IsPrimitiveCNode(node_ptr.first, prim::kPrimMirrorMicroStep) &&
!IsPrimitiveCNode(node_ptr.first, prim::kPrimMicroStepAllGather)) {
to_be_visited_set.emplace_back(node_ptr);
return to_be_visited_set;
}
auto node_set = node_users_map.at(node_ptr.first);
std::queue<std::pair<std::shared_ptr<AnfNode>, int>> visited;
for (auto &node_user : node_set) {
visited.push(node_user);
}
while (visited.size() >= 1) {
auto node = visited.front();
visited.pop();
if (!IsPrimitiveCNode(node.first, prim::kPrimMirrorMicroStep) &&
!IsPrimitiveCNode(node.first, prim::kPrimMicroStepAllGather)) {
to_be_visited_set.emplace_back(node);
} else {
auto next_node_set = node_users_map.at(node.first);
for (auto &node_user : next_node_set) {
visited.push(node_user);
}
}
}
return to_be_visited_set;
}
void AddVirtualAssignAdd(const FuncGraphPtr &root) {
auto parameters = root->parameters();
auto node_users_map = root->manager()->node_users();
@ -210,19 +291,14 @@ void AddVirtualAssignAdd(const FuncGraphPtr &root) {
}
auto node_users = node_users_map[parameter];
for (auto &temp_user : node_users) {
auto temp_node = temp_user.first;
// Micro virtual operator might be inserted after cast
if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
temp_node = node_users_map[temp_node].begin()->first;
auto temp_node = temp_user;
if (IsPrimitiveCNode(temp_node.first, prim::kPrimCast)) {
temp_node = *node_users_map[temp_node.first].begin();
}
if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
auto node_set = node_users_map[temp_node];
auto node_set = FindNextNode(temp_node, root, node_users_map);
for (auto &node_user : node_set) {
InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
}
} else {
InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter, node_users_map);
}
}
}

View File

@ -29,7 +29,7 @@ using PipelinePair = std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>>
AnfNodePtr FindAccuGrad(const CNodePtr &cnode);
bool IsLastStage();
void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
const AnfNodePtr &accu_parameter);
const AnfNodePtr &accu_parameter, const NodeUsersMap &node_user_map);
void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr &param);
AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name);
void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes);

View File

@ -380,6 +380,24 @@ void AddCommOpMeanFlag(const CNodePtr &comm_node) {
prim->SetAttrs(attrs);
}
void AddCommOpMirrorFlag(const CNodePtr &comm_node, bool do_mirror) {
MS_EXCEPTION_IF_NULL(comm_node);
auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
auto attrs = prim->attrs();
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
attrs[DO_MIRROR] = MakeValue<bool>(do_mirror);
prim->SetAttrs(attrs);
}
void AddCommOpAddAccuFlag(const CNodePtr &comm_node, bool add_accu) {
MS_EXCEPTION_IF_NULL(comm_node);
auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
auto attrs = prim->attrs();
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
attrs[ADD_ACCU] = MakeValue<bool>(add_accu);
prim->SetAttrs(attrs);
}
void AddCommOpParamFlag(const CNodePtr &comm_node) {
MS_EXCEPTION_IF_NULL(comm_node);
auto graph = comm_node->func_graph();
@ -437,7 +455,6 @@ Operator CreateMiniStepAllGatherOp(const std::string &group) {
Operator CreateMicroStepAllGatherOp(const std::string &group) {
bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
OperatorName operator_name = MICRO_STEP_ALL_GATHER;
ValuePtr attr0_value = MakeValue(group); // group
Attr attr0 = std::make_pair(GROUP, attr0_value);

View File

@ -303,6 +303,8 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &
Operator CreateAllGatherOp(const std::string &group);
Operator CreateMiniStepAllGatherOp(const std::string &group);
void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
void AddCommOpMirrorFlag(const CNodePtr &comm_node, bool do_mirror);
void AddCommOpAddAccuFlag(const CNodePtr &comm_node, bool add_accu);
Operator CreateMicroStepAllGatherOp(const std::string &group);
void AddCommOpMeanFlag(const CNodePtr &comm_node);
void AddCommOpParamFlag(const CNodePtr &comm_node);

View File

@ -141,6 +141,7 @@ constexpr char STRIDES[] = "strides";
constexpr char GROUP[] = "group";
constexpr char FUSION[] = "fusion";
constexpr char DO_MIRROR[] = "do_mirror";
constexpr char ADD_ACCU[] = "add_accu";
constexpr char RECOMPUTE[] = "recompute";
constexpr char RECOMPUTE_COMM_OP[] = "recompute_comm_op";
constexpr char NOT_RECOMPUTE[] = "not_recompute";
@ -407,6 +408,7 @@ constexpr char RESIZE_BILINEAR[] = "ResizeBilinear";
constexpr char RESIZE_NEAREST_NEIGHBOR[] = "ResizeNearestNeighbor";
// pipeline
constexpr size_t PIPELINE_FUSTION_OFFSET = 100;
constexpr char MICRO[] = "micro";
constexpr char DEST_RANK[] = "dest_rank";
constexpr char SRC_RANK[] = "src_rank";

View File

@ -453,6 +453,8 @@ void HandleFullySplitParameters(const FuncGraphPtr &root) {
void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root);
auto grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
for (auto &cloned_parameter_node : root->parameters()) {
MS_EXCEPTION_IF_NULL(cloned_parameter_node);
auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
@ -512,11 +514,20 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
// from pipeline or grad accumulation
if (param_name.find(ACCU_GRADS) != std::string::npos) {
auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
auto opt_shard_group = tensor_layout->opt_shard_group();
auto opt_shard_shape = cloned_from_parameter->user_data<TensorLayout>()->opt_shard_slice_shape();
std::shared_ptr<abstract::BaseShape> parallel_shape = nullptr;
// set opt shard shape if the pipeline sharding is set
if (grad_accumulation_shard && !opt_shard_group.empty()) {
parallel_shape = std::make_shared<abstract::Shape>(opt_shard_shape);
} else {
parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
}
MS_EXCEPTION_IF_NULL(parallel_shape);
cloned_abstract->set_shape(parallel_shape);
// in opt shard, accu_grad's shape is different from the original param's shape
if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
// if the grad_accumulation_shard is enabled, the accu_grads will be a opt-sharded shape
if (!grad_accumulation_shard && ParallelContext::GetInstance()->enable_parallel_optimizer()) {
TensorLayout new_layout = *tensor_layout;
new_layout.set_opt_shard_group("");
tensor_layout = std::make_shared<TensorLayout>(new_layout);
@ -526,6 +537,13 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
}
cloned_parameter->set_user_data<TensorLayout>(tensor_layout);
cloned_parameter_node->set_abstract(cloned_abstract);
// copy the fusion tag
auto cloned_param_info = cloned_parameter->param_info();
MS_EXCEPTION_IF_NULL(cloned_param_info);
auto cloned_from_param_info = cloned_from_parameter->param_info();
MS_EXCEPTION_IF_NULL(cloned_from_param_info);
cloned_param_info->set_comm_fusion(cloned_from_param_info->comm_fusion());
MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
<< " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
<< ", clone index is: " << cloned_index;

View File

@ -64,7 +64,7 @@ static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
// it will be one item in map with key: C, and value: (B, i)
std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) {
void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool do_mirror, bool accu_flag) {
if (new_node_input.empty()) {
return;
}
@ -73,7 +73,8 @@ void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool acc
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
attrs[DO_MIRROR] = MakeValue<bool>(!accu_flag);
attrs[DO_MIRROR] = MakeValue<bool>(do_mirror);
attrs[ADD_ACCU] = MakeValue<bool>(accu_flag);
prim->SetAttrs(attrs);
}
@ -189,7 +190,9 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
SetCommunicationOpGroupLabel(new_node_input);
// gradient accumulation
if (grad_accumulation_step > 1) {
SetMiniStepOpDoMirrorLabel(new_node_input, root->has_flag(ACCUMULATION));
bool add_accu = root->has_flag(ACCUMULATION);
// MiniStep need to do mirror at each micro step as we use the gradient accumulation sharding,
SetMiniStepOpDoMirrorLabel(new_node_input, !add_accu, !add_accu);
}
return new_node_input;
}
@ -1510,6 +1513,7 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
const AnfNodePtr &node, const std::string &op_name, bool is_shared_param) {
MS_EXCEPTION_IF_NULL(res.first);
MS_EXCEPTION_IF_NULL(node);
bool grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
auto cnode = res.first->cast<CNodePtr>();
auto graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(graph);
@ -1528,6 +1532,12 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
op = CreateAllGatherOp(group);
}
CNodePtr cast_node = InsertAllGatherAfterCast(cnode);
std::string opt_shard_mirror_group;
auto param_ptr = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (param_ptr->user_data<TensorLayout>()) {
opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
}
if (!is_shared_param && cast_node) {
allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root);
MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
@ -1541,6 +1551,17 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
AddCommOpFusionType(allgather, node);
// add gradients mean
AddCommOpMeanFlag(allgather);
if (op_name == MICRO_STEP_ALL_GATHER) {
// When grad_accumulation_shard is enabled, the ReduceScatter is inserted at each micro step
// so no need to do backward for the micro_step_allgather
AddCommOpMirrorFlag(allgather, !grad_accumulation_shard);
} else if (op_name == MINI_STEP_ALL_GATHER) {
// We need to manually set the add_accu to be false if it's father node is MirrorMiniStep
bool add_accu = root->has_flag(ACCUMULATION);
bool is_with_mirror = opt_shard_mirror_group.size() > 1;
AddCommOpAddAccuFlag(allgather, !add_accu && !is_with_mirror);
AddCommOpMirrorFlag(allgather, grad_accumulation_shard || !add_accu);
}
}
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,

View File

@ -140,6 +140,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
.def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
.def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.")
.def("get_grad_accumulation_shard", &ParallelContext::grad_accumulation_shard, "Get grad_accumulation_shard.")
.def("set_grad_accumulation_shard", &ParallelContext::set_grad_accumulation_shard, "Set grad_accumulation_shard.")
.def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
.def("get_gradients_mean", &ParallelContext::gradients_mean, "Get mirror mean.")
.def("set_gradients_mean", &ParallelContext::set_gradients_mean, "Set mirror mean.")

View File

@ -503,10 +503,11 @@ CNodePtr AscendStreamAssign::GetTargetOutputNode(const vector<CNodePtr> &moved_b
return nullptr;
}
for (; it < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*it).get()) != subgraph_id; it++) {
for (; it < cnode_ptr_list.end(); it++) {
auto inputs = GetInputKernels(*it);
for (auto &input : inputs) {
if (find(moved_backward_cnodes.begin(), moved_backward_cnodes.end(), input) != moved_backward_cnodes.end()) {
if (find(moved_backward_cnodes.begin(), moved_backward_cnodes.end(), input) != moved_backward_cnodes.end() &&
AnfAlgo::GetGraphId((*it).get()) != subgraph_id) {
MS_LOG(INFO) << "The nodes moved backward were used by nodes on different subgraphs, no need moved";
return nullptr;
}

View File

@ -387,6 +387,7 @@ constexpr auto kAttrN = "n";
constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active";
constexpr auto kAttrFpBpEnd = "fpbp_end";
constexpr auto kAttrFusion = "fusion";
constexpr auto kAttrNotDelayFusion = "not_delay_fusion";
constexpr auto kAttrGroup = "group";
constexpr auto kAttrGroups = "groups";
constexpr auto kAttrGroupBack = "group_back";

View File

@ -349,7 +349,8 @@ def _context():
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
parallel_optimizer_config=dict)
def set_auto_parallel_context(**kwargs):
r"""
Set auto parallel context, which is valid only for Ascend and GPU target.
@ -375,7 +376,7 @@ def set_auto_parallel_context(**kwargs):
parallel_mode strategy_ckpt_load_file
all_reduce_fusion_config strategy_ckpt_save_file
enable_parallel_optimizer dataset_strategy
\ pipeline_stages
parallel_optimizer_config pipeline_stages
\ grad_accumulation_step
=========================== ===========================
@ -432,6 +433,21 @@ def set_auto_parallel_context(**kwargs):
Default: 1.
grad_accumulation_step (int): Set the accumulation steps of gradients in auto and semi auto parallel mode.
This should be a positive int. Default: 1.
parallel_optimizer_config (dict): A dict contains the keys and values for setting the parallel optimizer
configure. The configure provides more detailed behavior control about parallel training
when parallel optimizer is enabled. Currently it supports the key `gradient_accumulation_shard`.
The configure will be effective when we use
context.set_auto_parallel_context(enable_parallel_optimizer=True).
It supports the following keys.
- gradient_accumulation_shard: If ture, the accumulation gradient parameters will be
sharded across the data parallel devices. This will
introduce additional communication(ReduceScatter) at
each step when accumulate the gradients, but saves a
lot of device memories, thus can make model be trained
with larger batch size. This configure is effective only
when the model runs on pipeline training or gradient
accumulation with data parallel.
Raises:
ValueError: If input key is not attribute in auto parallel context.
@ -451,6 +467,8 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
>>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
>>> context.set_auto_parallel_context(pipeline_stages=2)
>>> parallel_config = {"gradient_accumulation_shard": True}
>>> context.set_auto_parallel_context(parallel_optimizer_config=parallel_config, enable_parallel_optimizer=True)
"""
_set_auto_parallel_context(**kwargs)

View File

@ -18,7 +18,7 @@ from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.communication import get_rank, get_group_size
from mindspore.parallel._utils import _get_enable_parallel_optimizer
from mindspore.parallel._utils import _get_enable_parallel_optimizer, _get_grad_accumulation_shard
from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@ -131,8 +131,22 @@ def get_bprop_virtual_assign_add(self):
cast = P.Cast()
dtype = P.DType()
out_tensor = Tensor(0.0, mstype.float16)
reduce_scatter = None
group = self.get_attr_dict().get("group", None)
fusion = self.get_attr_dict().get("fusion", 0)
if group:
reduce_scatter = ReduceScatter(ReduceOp.SUM, group).add_prim_attr("fusion", fusion)
if self.instance_name:
instance_name = "_grad_accumulation_shard_grad" + self.instance_name
reduce_scatter.set_prim_instance_name(instance_name)
# For pipeline training, as the fused communication will be visited later
# this may make memory increase, so we need to add a tag to let the
# fused communication not be effective.
reduce_scatter.add_prim_attr("not_delay_fusion", True)
def bprop(x, y, out, dout):
if reduce_scatter:
dout = reduce_scatter(dout)
temp = assign_add(y, dout)
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), temp)
@ -237,8 +251,11 @@ def get_bprop_mini_step_all_gather(self):
fusion = self.get_attr_dict()["fusion"]
mean_flag = self.get_attr_dict()["mean_flag"]
do_mirror = self.get_attr_dict()["do_mirror"]
add_accu = self.get_attr_dict().get("add_accu", False)
gradient_shard = _get_grad_accumulation_shard()
scale = 1 / self.rank_size
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
assign_add = P.AssignAdd()
if self.instance_name:
instance_name = "grad_" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
@ -248,15 +265,21 @@ def get_bprop_mini_step_all_gather(self):
def bprop(x, z, out, dout):
if do_mirror:
if mean_flag:
if not gradient_shard:
z = F.depend(z, F.assign_add(z, dout))
grad = all_reduce(z)
dx = split(grad)[rank]
if mean_flag:
dx = F.tensor_mul(dx, scale)
else:
z = F.depend(z, F.assign_add(z, dout))
grad = all_reduce(z)
dout = F.depend(dout, z)
grad = all_reduce(dout)
dx = split(grad)[rank]
if mean_flag:
dx = F.tensor_mul(dx, scale)
if add_accu:
z = assign_add(z, dx)
dx = F.depend(dx, z)
else:
dx = dout
return (dx, zeros_like(z))
@ -269,6 +292,7 @@ def get_bprop_micro_step_all_gather(self):
"""Generate bprop for _MicroStepAllGather"""
fusion = self.get_attr_dict()["fusion"]
mean_flag = self.get_attr_dict()["mean_flag"]
do_mirror = self.get_attr_dict()["do_mirror"]
scale = 1 / self.rank_size
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
rank = get_rank(self.group)
@ -284,6 +308,8 @@ def get_bprop_micro_step_all_gather(self):
# z: accu_grad
def bprop(x, z, out, dout):
z = F.depend(z, dout)
if not do_mirror:
return (z, cast(out_tensor, dtype(z)))
real_grad = all_reduce(z)
real_grad = split(real_grad)[rank]
if mean_flag:

View File

@ -5,4 +5,4 @@ e
bprop.8:x*
bprop.8:out*
bprop.8:dout2
bprop.8:[CNode]:1:@74787be4234cdeb03f214519cd8358a5f4ad2f5606dbeb494462cddc448eb4beP
bprop.8:[CNode]:1:@96c75d48466ae9dd2ae51ee64181426e1bf1c36337f7c6cf3bdd01083bfb1a6eP

View File

@ -299,6 +299,7 @@ class _MicroStepAllGather(PrimitiveWithInfer):
self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 1)
self.add_prim_attr('do_mirror', False)
self.mean_flag = mean_flag
def infer_shape(self, x_shape, z_shape):

View File

@ -13,9 +13,9 @@
# limitations under the License.
# ============================================================================
"""Context of auto parallel"""
import os
import threading
import mindspore.context as context
from mindspore import context
import mindspore.log as logger
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore.parallel._ps_context import _is_role_pserver
@ -27,6 +27,13 @@ _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"
class _ParallelOptimizerConfig:
"""
The key of the Parallel Optimizer. There are three
"""
GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard"
class _AutoParallelContext:
"""
_AutoParallelContext is the environment in which operations are executed
@ -326,7 +333,6 @@ class _AutoParallelContext:
strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint.
"""
self.check_context_handle()
import os
dir_path = os.path.dirname(strategy_ckpt_save_file)
if dir_path and not os.path.exists(dir_path):
os.makedirs(dir_path)
@ -340,7 +346,6 @@ class _AutoParallelContext:
def set_group_ckpt_save_file(self, group_ckpt_save_file):
"""Set group checkpoint save path."""
self.check_context_handle()
import os
dir_path = os.path.dirname(group_ckpt_save_file)
if dir_path and not os.path.exists(dir_path):
os.makedirs(dir_path)
@ -489,6 +494,41 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_enable_parallel_optimizer()
def set_parallel_optimizer_config(self, parallel_optimizer_config):
"""
Set the configure for parallel optimizer. The configure provides more detailed behavior control about parallel
training when parallel optimizer is enabled.
Currently it supports the key `gradient_accumulation_shard`. The configure will be effective
when we use context.set_auto_parallel_context(enable_parallel_optimizer=True).
Args:
parallel_optimizer_config(dict): A dict contains the keys and values for setting the parallel optimizer
configure. It supports the following keys:
- gradient_accumulation_shard: If ture, the accumulation gradient parameters will be sharded
across the data parallel devices. This will introduce additional
communication(ReduceScatter) at each step when accumulate the
gradients, but saves a lot of device memories,
thus can make model be trained with larger batch size.
This configure is effective only when the model runs on pipeline
training or gradient accumulation with data parallel.
"""
self.check_context_handle()
grad_shard_name = _ParallelOptimizerConfig.GRADIENT_ACCUMULATION_SHARD
if grad_shard_name in parallel_optimizer_config:
Validator.check_bool(
parallel_optimizer_config[grad_shard_name], grad_shard_name, grad_shard_name)
self._context_handle.set_grad_accumulation_shard(
parallel_optimizer_config[grad_shard_name])
else:
raise ValueError(f"The parallel_optimizer_config doest not contains {grad_shard_name}, please check your "
f"parallel_optimizer_config")
def get_grad_accumulation_shard(self):
self.check_context_handle()
return self._context_handle.get_grad_accumulation_shard()
def set_sharding_propagation(self, sharding_propagation):
"""
Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
@ -648,6 +688,7 @@ _set_auto_parallel_context_func_map = {
"full_batch": auto_parallel_context().set_full_batch,
"dataset_strategy": auto_parallel_context().set_dataset_strategy,
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
"parallel_optimizer_config": auto_parallel_context().set_parallel_optimizer_config,
"grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
"all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
@ -802,5 +843,6 @@ def _reset_auto_parallel_context():
- enable_parallel_optimizer: False
- auto_parallel_search_mode: dynamic_programming
- pipeline_stages: 0
- gradient_accumulation_shard: True
"""
auto_parallel_context().reset()

View File

@ -219,6 +219,11 @@ def _get_enable_parallel_optimizer():
return auto_parallel_context().get_enable_parallel_optimizer()
def _get_grad_accumulation_shard():
"""Get if using parallel shard."""
return auto_parallel_context().get_grad_accumulation_shard()
def _device_number_check(parallel_mode, device_number):
"""
Check device num.

View File

@ -18,6 +18,7 @@ import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore import context
from mindspore.ops import composite as C
from mindspore.ops import functional as F
import mindspore.ops as P
@ -250,11 +251,18 @@ def test_transformer_model_auto_parallel_no_support():
mode=ParallelMode.AUTO_PARALLEL)
def test_pipeline_single_transformer():
def pipeline_single_transformer(grad_accumulation_shard=False):
"""
Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
Expectation: The compile passed
"""
set_auto_parallel_context(device_num=32,
full_batch=True,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
context.set_auto_parallel_context(parallel_optimizer_config=
{"gradient_accumulation_shard": grad_accumulation_shard})
net = Transformer(batch_size=4 // pipeline_config.micro_batch_num,
src_seq_length=20,
@ -286,6 +294,24 @@ def test_pipeline_single_transformer():
model.train(1, dataset, dataset_sink_mode=False)
def test_pipeline_transformer_gradient_shard_true():
"""
Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard True
Expectation: The compile passed
"""
pipeline_single_transformer(grad_accumulation_shard=True)
def test_pipeline_transformer_gradient_shard_false():
"""
Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
Expectation: The compile passed
"""
pipeline_single_transformer(grad_accumulation_shard=False)
def test_transformer_wrong_head():
set_auto_parallel_context(device_num=32,
full_batch=True,