forked from mindspore-Ecosystem/mindspore
Add pipeline shard interface

- Add support for accu gradient without pipeline
- Add a delay tag for the fusion op
- Optimize the visit order
- Add mirror for mini step control
- Move the group to attributes
- Add gradient_shard control for the mini step
- Fix code style
- Fix UT descriptions
- Add interface
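For orientation, a minimal usage sketch of the interface this change exposes, pieced together from the docstring and test updates further down in this diff; the surrounding network and training code is omitted:

from mindspore import context
from mindspore.context import ParallelMode

# Shard the accumulated gradients across the data-parallel devices when the
# parallel optimizer is enabled; this adds a per-micro-step ReduceScatter but
# saves device memory in pipeline or gradient-accumulation training.
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                  device_num=32,
                                  pipeline_stages=2,
                                  enable_parallel_optimizer=True,
                                  parallel_optimizer_config={"gradient_accumulation_shard": True})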
parent a3441bbfb5
commit f354ab22a3
@@ -389,6 +389,9 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
    auto final_node_prim = GetCNodePrimitive(final_node);
    fused_prim->set_instance_name(final_node_prim->instance_name());
  }
  if (AnfAlgo::HasNodeAttr(kAttrNotDelayFusion, final_node)) {
    AnfAlgo::CopyNodeAttr(kAttrNotDelayFusion, final_node, fused_node);
  }
  return fused_node;
}
@@ -1721,13 +1721,18 @@ bool AnfRuntimeAlgorithm::IsFusedCommunicationOp(const AnfNodePtr &node) {
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
  ValuePtr attr_not_delay_fusion = primitive->GetAttr(kAttrNotDelayFusion);
  if (attr_fusion == nullptr) {
    return false;
  }

  auto fusion = GetValue<int64_t>(attr_fusion);
  if (fusion == 0) {
    return false;
  }
  if (attr_not_delay_fusion && GetValue<bool>(attr_not_delay_fusion)) {
    return false;
  }
  return true;
}
@@ -74,6 +74,7 @@ void ParallelContext::Reset() {
  optimizer_weight_shard_aggregated_save_ = false;
  sharding_propagation_ = false;
  enable_all2all_ = false;
  grad_accumulation_shard_ = true;
  dataset_strategy_.clear();
}
@@ -127,6 +127,10 @@ class ParallelContext {
  void set_hccl_test_available(bool hccl_test_available) { hccl_test_available_ = hccl_test_available; }
  bool hccl_test_available() const { return hccl_test_available_; }
  void set_grad_accumulation_shard(const bool grad_accumulation_shard) {
    grad_accumulation_shard_ = grad_accumulation_shard;
  }
  bool grad_accumulation_shard() const { return grad_accumulation_shard_; }

  bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
  std::string communi_parallel_mode() const { return communi_parallel_mode_; }
@@ -174,6 +178,7 @@ class ParallelContext {
  std::string communi_parallel_mode_;
  int64_t optimizer_weight_shard_size_;
  bool optimizer_weight_shard_aggregated_save_;
  bool grad_accumulation_shard_;
  // In AUTO_PARALLEL mode, 'sharding_propagation_' = True indicates that sharding-configured operators
  // will propagate the sharding strategies to other operators with minimum redistribution cost.
  bool sharding_propagation_;
@@ -14,10 +14,10 @@
 * limitations under the License.
 */

#include <iterator>
#include <memory>
#include <list>
#include <set>
#include <queue>
#include <algorithm>
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/graph_util/generate_graph.h"
@@ -108,14 +108,37 @@ void SetStridedSliceStrategy(const AnfNodePtr &node) {
  cnode->AddPrimalAttr(IN_STRATEGY, strategy);
}

CNodePtr FindNodeWithMircoSize(const AnfNodePtr &node_user, const FuncGraphManagerPtr &manager,
                               const NodeUsersMap &node_users_map) {
  // Recursively find the micro tag; this may take much more time if there are many layers
  std::queue<AnfNodePtr> visited;
  visited.push(node_user);
  while (!visited.empty()) {
    auto cur_node = visited.front();
    visited.pop();
    auto users = node_users_map.at(cur_node);
    for (auto &temp_user : users) {
      auto cnode = temp_user.first->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (!cnode->HasPrimalAttr(MICRO)) {
        visited.push(temp_user.first);
      } else {
        return cnode;
      }
    }
  }
  return nullptr;
}

void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
                            const AnfNodePtr &accu_parameter) {
                            const AnfNodePtr &accu_parameter, const NodeUsersMap &node_user_map) {
  auto cnode = node_user.first->cast<CNodePtr>();
  if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
  bool grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
  if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && enable_parallel_optimizer) {
    return;
  }
@@ -124,9 +147,36 @@ void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const F
    MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd.";
    return;
  }
  OperatorAttrs attrs;
  auto param_ptr = accu_parameter->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param_ptr);
  // If grad_accumulation_shard is true, a ReduceScatter will be inserted at each micro step,
  // so the fusion id should be different for each micro step;
  // otherwise they will be fused into one ReduceScatter across micro steps.
  // If grad_accumulation_shard is false, we pass an empty group, so no ReduceScatter will be inserted.
  ValuePtr args1 = nullptr;
  ValuePtr args2 = nullptr;
  ValuePtr micro = nullptr;
  int64_t step = 0;
  if (grad_accumulation_shard) {
    auto cnode_with_micro_size = FindNodeWithMircoSize(cnode, manager, node_user_map);
    if (cnode_with_micro_size && cnode_with_micro_size->HasPrimalAttr(MICRO)) {
      micro = cnode_with_micro_size->GetPrimalAttr(MICRO);
      step = GetValue<int64_t>(micro);
    }
  }
  args1 = MakeValue(param_ptr->user_data<TensorLayout>()->opt_shard_group());
  args2 = MakeValue(param_ptr->param_info()->comm_fusion() + step * PIPELINE_FUSTION_OFFSET);
  OperatorAttrs attrs = {};
  auto py_instance = CreatOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
  auto value_node = NewValueNode(py_instance);
  // Set the attributes of the ReduceScatter
  auto new_prim = GetValueNode<PrimitivePtr>(value_node);
  MS_EXCEPTION_IF_NULL(new_prim);
  auto attrs_prim = new_prim->attrs();
  attrs_prim[GROUP] = args1;
  attrs_prim[kAttrFusion] = args2;
  new_prim->SetAttrs(attrs_prim);

  std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
  auto graph = cnode->func_graph();
  auto virtual_node = graph->NewCNode(virtual_node_input);
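The comment above describes the fusion-id scheme for the per-micro-step ReduceScatter. A small illustrative sketch of that computation follows (Python, hypothetical helper name; the constant value 100 comes from PIPELINE_FUSTION_OFFSET added in the constants hunk further down):

PIPELINE_FUSTION_OFFSET = 100  # mirrors the constexpr added below

def reduce_scatter_fusion_id(comm_fusion, micro_step, grad_accumulation_shard):
    # When sharding is off, step stays 0 and every micro step shares one fusion id.
    step = micro_step if grad_accumulation_shard else 0
    return comm_fusion + step * PIPELINE_FUSTION_OFFSET

# With comm_fusion = 1, micro steps 0, 1, 2 map to fusion ids 1, 101, 201 when
# sharding is enabled, so their ReduceScatter ops are not fused across micro steps.
print([reduce_scatter_fusion_id(1, m, True) for m in range(3)])   # [1, 101, 201]
print([reduce_scatter_fusion_id(1, m, False) for m in range(3)])  # [1, 1, 1]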
@@ -189,16 +239,47 @@ void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
        IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
      auto node_set = node_users_map[temp_node];
      for (auto &node_user : node_set) {
        InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
        InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter, node_users_map);
      }
    } else {
      InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
      InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter, node_users_map);
    }
  }
  InsertVirtualAccuGrad(node, root->manager(), accu_parameter);
}
}

// If the graph looks like the following:
// 1. MicroStepAllGather->MirrorMicro->Load, we need to visit the param after the Load
std::vector<std::pair<AnfNodePtr, int>> FindNextNode(const std::pair<AnfNodePtr, int> &node_ptr,
                                                     const FuncGraphPtr &root, const NodeUsersMap &node_users_map) {
  std::vector<std::pair<AnfNodePtr, int>> to_be_visited_set;
  if (!IsPrimitiveCNode(node_ptr.first, prim::kPrimMirrorMicroStep) &&
      !IsPrimitiveCNode(node_ptr.first, prim::kPrimMicroStepAllGather)) {
    to_be_visited_set.emplace_back(node_ptr);
    return to_be_visited_set;
  }
  auto node_set = node_users_map.at(node_ptr.first);
  std::queue<std::pair<std::shared_ptr<AnfNode>, int>> visited;
  for (auto &node_user : node_set) {
    visited.push(node_user);
  }
  while (visited.size() >= 1) {
    auto node = visited.front();
    visited.pop();
    if (!IsPrimitiveCNode(node.first, prim::kPrimMirrorMicroStep) &&
        !IsPrimitiveCNode(node.first, prim::kPrimMicroStepAllGather)) {
      to_be_visited_set.emplace_back(node);
    } else {
      auto next_node_set = node_users_map.at(node.first);
      for (auto &node_user : next_node_set) {
        visited.push(node_user);
      }
    }
  }
  return to_be_visited_set;
}

void AddVirtualAssignAdd(const FuncGraphPtr &root) {
  auto parameters = root->parameters();
  auto node_users_map = root->manager()->node_users();
@@ -210,19 +291,14 @@ void AddVirtualAssignAdd(const FuncGraphPtr &root) {
    }
    auto node_users = node_users_map[parameter];
    for (auto &temp_user : node_users) {
      auto temp_node = temp_user.first;
      // Micro virtual operator might be inserted after cast
      if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
        temp_node = node_users_map[temp_node].begin()->first;
      auto temp_node = temp_user;
      if (IsPrimitiveCNode(temp_node.first, prim::kPrimCast)) {
        temp_node = *node_users_map[temp_node.first].begin();
      }
      if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
          IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
        auto node_set = node_users_map[temp_node];
        auto node_set = FindNextNode(temp_node, root, node_users_map);
        for (auto &node_user : node_set) {
          InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter);
        }
      } else {
        InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter);
        InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter, node_users_map);
      }
    }
  }
}
@@ -29,7 +29,7 @@ using PipelinePair = std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>>
AnfNodePtr FindAccuGrad(const CNodePtr &cnode);
bool IsLastStage();
void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
                            const AnfNodePtr &accu_parameter);
                            const AnfNodePtr &accu_parameter, const NodeUsersMap &node_user_map);
void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr &param);
AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name);
void HandleReceiveParam(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes);
@@ -380,6 +380,24 @@ void AddCommOpMeanFlag(const CNodePtr &comm_node) {
  prim->SetAttrs(attrs);
}

void AddCommOpMirrorFlag(const CNodePtr &comm_node, bool do_mirror) {
  MS_EXCEPTION_IF_NULL(comm_node);
  auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
  auto attrs = prim->attrs();
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  attrs[DO_MIRROR] = MakeValue<bool>(do_mirror);
  prim->SetAttrs(attrs);
}

void AddCommOpAddAccuFlag(const CNodePtr &comm_node, bool add_accu) {
  MS_EXCEPTION_IF_NULL(comm_node);
  auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
  auto attrs = prim->attrs();
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  attrs[ADD_ACCU] = MakeValue<bool>(add_accu);
  prim->SetAttrs(attrs);
}

void AddCommOpParamFlag(const CNodePtr &comm_node) {
  MS_EXCEPTION_IF_NULL(comm_node);
  auto graph = comm_node->func_graph();
@@ -437,7 +455,6 @@ Operator CreateMiniStepAllGatherOp(const std::string &group) {

Operator CreateMicroStepAllGatherOp(const std::string &group) {
  bool mean_flag = ParallelContext::GetInstance()->gradients_mean();

  OperatorName operator_name = MICRO_STEP_ALL_GATHER;
  ValuePtr attr0_value = MakeValue(group);  // group
  Attr attr0 = std::make_pair(GROUP, attr0_value);
@@ -303,6 +303,8 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &
Operator CreateAllGatherOp(const std::string &group);
Operator CreateMiniStepAllGatherOp(const std::string &group);
void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node);
void AddCommOpMirrorFlag(const CNodePtr &comm_node, bool do_mirror);
void AddCommOpAddAccuFlag(const CNodePtr &comm_node, bool add_accu);
Operator CreateMicroStepAllGatherOp(const std::string &group);
void AddCommOpMeanFlag(const CNodePtr &comm_node);
void AddCommOpParamFlag(const CNodePtr &comm_node);
@@ -141,6 +141,7 @@ constexpr char STRIDES[] = "strides";
constexpr char GROUP[] = "group";
constexpr char FUSION[] = "fusion";
constexpr char DO_MIRROR[] = "do_mirror";
constexpr char ADD_ACCU[] = "add_accu";
constexpr char RECOMPUTE[] = "recompute";
constexpr char RECOMPUTE_COMM_OP[] = "recompute_comm_op";
constexpr char NOT_RECOMPUTE[] = "not_recompute";
@@ -407,6 +408,7 @@ constexpr char RESIZE_BILINEAR[] = "ResizeBilinear";
constexpr char RESIZE_NEAREST_NEIGHBOR[] = "ResizeNearestNeighbor";

// pipeline
constexpr size_t PIPELINE_FUSTION_OFFSET = 100;
constexpr char MICRO[] = "micro";
constexpr char DEST_RANK[] = "dest_rank";
constexpr char SRC_RANK[] = "src_rank";
@@ -453,6 +453,8 @@ void HandleFullySplitParameters(const FuncGraphPtr &root) {

void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  auto grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();

  for (auto &cloned_parameter_node : root->parameters()) {
    MS_EXCEPTION_IF_NULL(cloned_parameter_node);
    auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
@@ -512,11 +514,20 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
    // from pipeline or grad accumulation
    if (param_name.find(ACCU_GRADS) != std::string::npos) {
      auto slice_shape = cloned_from_parameter->user_data<TensorLayout>()->slice_shape().array();
      std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
      auto opt_shard_group = tensor_layout->opt_shard_group();
      auto opt_shard_shape = cloned_from_parameter->user_data<TensorLayout>()->opt_shard_slice_shape();
      std::shared_ptr<abstract::BaseShape> parallel_shape = nullptr;
      // use the opt shard shape if pipeline gradient-accumulation sharding is set
      if (grad_accumulation_shard && !opt_shard_group.empty()) {
        parallel_shape = std::make_shared<abstract::Shape>(opt_shard_shape);
      } else {
        parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
      }
      MS_EXCEPTION_IF_NULL(parallel_shape);
      cloned_abstract->set_shape(parallel_shape);
      // in opt shard, accu_grad's shape is different from the original param's shape
      if (ParallelContext::GetInstance()->enable_parallel_optimizer()) {
      // if grad_accumulation_shard is enabled, the accu_grads keep the opt-sharded shape
      if (!grad_accumulation_shard && ParallelContext::GetInstance()->enable_parallel_optimizer()) {
        TensorLayout new_layout = *tensor_layout;
        new_layout.set_opt_shard_group("");
        tensor_layout = std::make_shared<TensorLayout>(new_layout);
@@ -526,6 +537,13 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
    }
    cloned_parameter->set_user_data<TensorLayout>(tensor_layout);
    cloned_parameter_node->set_abstract(cloned_abstract);
    // copy the fusion tag
    auto cloned_param_info = cloned_parameter->param_info();
    MS_EXCEPTION_IF_NULL(cloned_param_info);
    auto cloned_from_param_info = cloned_from_parameter->param_info();
    MS_EXCEPTION_IF_NULL(cloned_from_param_info);
    cloned_param_info->set_comm_fusion(cloned_from_param_info->comm_fusion());

    MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                 << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
                 << ", clone index is: " << cloned_index;
@@ -64,7 +64,7 @@ static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
// it will be one item in map with key: C, and value: (B, i)
std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;

void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) {
void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool do_mirror, bool accu_flag) {
  if (new_node_input.empty()) {
    return;
  }
@@ -73,7 +73,8 @@ void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool acc
  MS_EXCEPTION_IF_NULL(prim);

  auto attrs = prim->attrs();
  attrs[DO_MIRROR] = MakeValue<bool>(!accu_flag);
  attrs[DO_MIRROR] = MakeValue<bool>(do_mirror);
  attrs[ADD_ACCU] = MakeValue<bool>(accu_flag);
  prim->SetAttrs(attrs);
}
@@ -189,7 +190,9 @@ std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operat
  SetCommunicationOpGroupLabel(new_node_input);
  // gradient accumulation
  if (grad_accumulation_step > 1) {
    SetMiniStepOpDoMirrorLabel(new_node_input, root->has_flag(ACCUMULATION));
    bool add_accu = root->has_flag(ACCUMULATION);
    // MiniStep needs to do mirror at each micro step since we use gradient accumulation sharding
    SetMiniStepOpDoMirrorLabel(new_node_input, !add_accu, !add_accu);
  }
  return new_node_input;
}
@@ -1510,6 +1513,7 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
                              const AnfNodePtr &node, const std::string &op_name, bool is_shared_param) {
  MS_EXCEPTION_IF_NULL(res.first);
  MS_EXCEPTION_IF_NULL(node);
  bool grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
  auto cnode = res.first->cast<CNodePtr>();
  auto graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
@@ -1528,6 +1532,12 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
    op = CreateAllGatherOp(group);
  }
  CNodePtr cast_node = InsertAllGatherAfterCast(cnode);
  std::string opt_shard_mirror_group;
  auto param_ptr = node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param_ptr);
  if (param_ptr->user_data<TensorLayout>()) {
    opt_shard_mirror_group = param_ptr->user_data<TensorLayout>()->opt_shard_mirror_group();
  }
  if (!is_shared_param && cast_node) {
    allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root);
    MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name;
@@ -1541,6 +1551,17 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group
  AddCommOpFusionType(allgather, node);
  // add gradients mean
  AddCommOpMeanFlag(allgather);
  if (op_name == MICRO_STEP_ALL_GATHER) {
    // When grad_accumulation_shard is enabled, the ReduceScatter is inserted at each micro step,
    // so there is no need to do backward for the micro_step_allgather
    AddCommOpMirrorFlag(allgather, !grad_accumulation_shard);
  } else if (op_name == MINI_STEP_ALL_GATHER) {
    // We need to manually set add_accu to false if its parent node is MirrorMiniStep
    bool add_accu = root->has_flag(ACCUMULATION);
    bool is_with_mirror = opt_shard_mirror_group.size() > 1;
    AddCommOpAddAccuFlag(allgather, !add_accu && !is_with_mirror);
    AddCommOpMirrorFlag(allgather, grad_accumulation_shard || !add_accu);
  }
}

static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
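To summarize the flag selection in the hunk above, here is an illustrative paraphrase in Python (not part of the change; accumulation_flag stands for root->has_flag(ACCUMULATION) and with_opt_shard_mirror for a non-trivial opt_shard_mirror_group):

def allgather_backward_flags(op_name, grad_accumulation_shard, accumulation_flag, with_opt_shard_mirror):
    """Attributes attached to the inserted AllGather that steer its backward pass."""
    if op_name == "MicroStepAllGather":
        # The per-micro-step ReduceScatter already accumulates the gradient when
        # sharding is enabled, so the backward mirror can be skipped.
        return {"do_mirror": not grad_accumulation_shard}
    if op_name == "MiniStepAllGather":
        return {"add_accu": not accumulation_flag and not with_opt_shard_mirror,
                "do_mirror": grad_accumulation_shard or not accumulation_flag}
    return {}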
@@ -140,6 +140,8 @@ PYBIND11_MODULE(_c_expression, m) {
    .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
    .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
    .def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.")
    .def("get_grad_accumulation_shard", &ParallelContext::grad_accumulation_shard, "Get grad_accumulation_shard.")
    .def("set_grad_accumulation_shard", &ParallelContext::set_grad_accumulation_shard, "Set grad_accumulation_shard.")
    .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
    .def("get_gradients_mean", &ParallelContext::gradients_mean, "Get mirror mean.")
    .def("set_gradients_mean", &ParallelContext::set_gradients_mean, "Set mirror mean.")
@@ -503,10 +503,11 @@ CNodePtr AscendStreamAssign::GetTargetOutputNode(const vector<CNodePtr> &moved_b
    return nullptr;
  }

  for (; it < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*it).get()) != subgraph_id; it++) {
  for (; it < cnode_ptr_list.end(); it++) {
    auto inputs = GetInputKernels(*it);
    for (auto &input : inputs) {
      if (find(moved_backward_cnodes.begin(), moved_backward_cnodes.end(), input) != moved_backward_cnodes.end()) {
      if (find(moved_backward_cnodes.begin(), moved_backward_cnodes.end(), input) != moved_backward_cnodes.end() &&
          AnfAlgo::GetGraphId((*it).get()) != subgraph_id) {
        MS_LOG(INFO) << "The nodes moved backward were used by nodes on different subgraphs, no need moved";
        return nullptr;
      }
@@ -387,6 +387,7 @@ constexpr auto kAttrN = "n";
constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active";
constexpr auto kAttrFpBpEnd = "fpbp_end";
constexpr auto kAttrFusion = "fusion";
constexpr auto kAttrNotDelayFusion = "not_delay_fusion";
constexpr auto kAttrGroup = "group";
constexpr auto kAttrGroups = "groups";
constexpr auto kAttrGroupBack = "group_back";
@@ -349,7 +349,8 @@ def _context():
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                 auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
                 all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int,
                 parallel_optimizer_config=dict)
def set_auto_parallel_context(**kwargs):
    r"""
    Set auto parallel context, which is valid only for Ascend and GPU target.
@@ -375,7 +376,7 @@ def set_auto_parallel_context(**kwargs):
             parallel_mode                  strategy_ckpt_load_file
             all_reduce_fusion_config       strategy_ckpt_save_file
             enable_parallel_optimizer      dataset_strategy
             \                              pipeline_stages
             parallel_optimizer_config      pipeline_stages
             \                              grad_accumulation_step
    ===========================  ===========================
@@ -432,6 +433,21 @@
            Default: 1.
        grad_accumulation_step (int): Set the accumulation steps of gradients in auto and semi auto parallel mode.
            This should be a positive int. Default: 1.
        parallel_optimizer_config (dict): A dict containing the keys and values for setting the parallel optimizer
            configuration. The configuration provides more detailed behavior control for parallel training
            when the parallel optimizer is enabled. Currently it supports the key `gradient_accumulation_shard`.
            The configuration takes effect when we use
            context.set_auto_parallel_context(enable_parallel_optimizer=True).
            It supports the following keys.

            - gradient_accumulation_shard: If true, the accumulated gradient parameters will be
                                           sharded across the data parallel devices. This
                                           introduces additional communication (ReduceScatter) at
                                           each step when accumulating the gradients, but saves a
                                           lot of device memory, so the model can be trained
                                           with a larger batch size. This configuration is effective only
                                           when the model runs in pipeline training or gradient
                                           accumulation with data parallel.

    Raises:
        ValueError: If the input key is not an attribute of the auto parallel context.
@@ -451,6 +467,8 @@
        >>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
        >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
        >>> context.set_auto_parallel_context(pipeline_stages=2)
        >>> parallel_config = {"gradient_accumulation_shard": True}
        >>> context.set_auto_parallel_context(parallel_optimizer_config=parallel_config, enable_parallel_optimizer=True)
    """
    _set_auto_parallel_context(**kwargs)
@@ -18,7 +18,7 @@ from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.communication import get_rank, get_group_size
from mindspore.parallel._utils import _get_enable_parallel_optimizer
from mindspore.parallel._utils import _get_enable_parallel_optimizer, _get_grad_accumulation_shard
from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@@ -131,8 +131,22 @@ def get_bprop_virtual_assign_add(self):
    cast = P.Cast()
    dtype = P.DType()
    out_tensor = Tensor(0.0, mstype.float16)
    reduce_scatter = None
    group = self.get_attr_dict().get("group", None)
    fusion = self.get_attr_dict().get("fusion", 0)
    if group:
        reduce_scatter = ReduceScatter(ReduceOp.SUM, group).add_prim_attr("fusion", fusion)
        if self.instance_name:
            instance_name = "_grad_accumulation_shard_grad" + self.instance_name
            reduce_scatter.set_prim_instance_name(instance_name)
        # For pipeline training, the fused communication will be visited later,
        # which may increase memory usage, so we add a tag so that the
        # fused communication is not made effective.
        reduce_scatter.add_prim_attr("not_delay_fusion", True)

    def bprop(x, y, out, dout):
        if reduce_scatter:
            dout = reduce_scatter(dout)
        temp = assign_add(y, dout)
        return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), temp)
@@ -237,8 +251,11 @@ def get_bprop_mini_step_all_gather(self):
    fusion = self.get_attr_dict()["fusion"]
    mean_flag = self.get_attr_dict()["mean_flag"]
    do_mirror = self.get_attr_dict()["do_mirror"]
    add_accu = self.get_attr_dict().get("add_accu", False)
    gradient_shard = _get_grad_accumulation_shard()
    scale = 1 / self.rank_size
    all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
    assign_add = P.AssignAdd()
    if self.instance_name:
        instance_name = "grad_" + self.instance_name
        all_reduce.set_prim_instance_name(instance_name)
@@ -248,15 +265,21 @@

    def bprop(x, z, out, dout):
        if do_mirror:
            if mean_flag:
            if not gradient_shard:
                z = F.depend(z, F.assign_add(z, dout))
                grad = all_reduce(z)
                dx = split(grad)[rank]
                if mean_flag:
                    dx = F.tensor_mul(dx, scale)
            else:
                z = F.depend(z, F.assign_add(z, dout))
                grad = all_reduce(z)
                dout = F.depend(dout, z)
                grad = all_reduce(dout)
                dx = split(grad)[rank]
                if mean_flag:
                    dx = F.tensor_mul(dx, scale)
                if add_accu:
                    z = assign_add(z, dx)
                    dx = F.depend(dx, z)
        else:
            dx = dout
        return (dx, zeros_like(z))
@@ -269,6 +292,7 @@ def get_bprop_micro_step_all_gather(self):
    """Generate bprop for _MicroStepAllGather"""
    fusion = self.get_attr_dict()["fusion"]
    mean_flag = self.get_attr_dict()["mean_flag"]
    do_mirror = self.get_attr_dict()["do_mirror"]
    scale = 1 / self.rank_size
    all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
    rank = get_rank(self.group)
@@ -284,6 +308,8 @@
    # z: accu_grad
    def bprop(x, z, out, dout):
        z = F.depend(z, dout)
        if not do_mirror:
            return (z, cast(out_tensor, dtype(z)))
        real_grad = all_reduce(z)
        real_grad = split(real_grad)[rank]
        if mean_flag:
@@ -5,4 +5,4 @@ e
bprop.8:x*
bprop.8:out*
bprop.8:dout2
bprop.8:[CNode]:1:@74787be4234cdeb03f214519cd8358a5f4ad2f5606dbeb494462cddc448eb4beP
bprop.8:[CNode]:1:@96c75d48466ae9dd2ae51ee64181426e1bf1c36337f7c6cf3bdd01083bfb1a6eP
@@ -299,6 +299,7 @@ class _MicroStepAllGather(PrimitiveWithInfer):
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 1)
        self.add_prim_attr('do_mirror', False)
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
@@ -13,9 +13,9 @@
# limitations under the License.
# ============================================================================
"""Context of auto parallel"""
import os
import threading

import mindspore.context as context
from mindspore import context
import mindspore.log as logger
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore.parallel._ps_context import _is_role_pserver
@@ -27,6 +27,13 @@ _DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"


class _ParallelOptimizerConfig:
    """
    The key of the parallel optimizer configuration.
    """
    GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard"


class _AutoParallelContext:
    """
    _AutoParallelContext is the environment in which operations are executed
@@ -326,7 +333,6 @@
            strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint.
        """
        self.check_context_handle()
        import os
        dir_path = os.path.dirname(strategy_ckpt_save_file)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
@@ -340,7 +346,6 @@
    def set_group_ckpt_save_file(self, group_ckpt_save_file):
        """Set group checkpoint save path."""
        self.check_context_handle()
        import os
        dir_path = os.path.dirname(group_ckpt_save_file)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
@@ -489,6 +494,41 @@
        self.check_context_handle()
        return self._context_handle.get_enable_parallel_optimizer()

    def set_parallel_optimizer_config(self, parallel_optimizer_config):
        """
        Set the configuration for the parallel optimizer. The configuration provides more detailed behavior control
        about parallel training when the parallel optimizer is enabled.
        Currently it supports the key `gradient_accumulation_shard`. The configuration takes effect
        when we use context.set_auto_parallel_context(enable_parallel_optimizer=True).

        Args:
            parallel_optimizer_config(dict): A dict containing the keys and values for setting the parallel optimizer
                configuration. It supports the following keys:

                - gradient_accumulation_shard: If true, the accumulated gradient parameters will be sharded
                  across the data parallel devices. This introduces additional
                  communication (ReduceScatter) at each step when accumulating the
                  gradients, but saves a lot of device memory, so the model can be
                  trained with a larger batch size. This configuration is effective
                  only when the model runs in pipeline training or gradient
                  accumulation with data parallel.
        """
        self.check_context_handle()
        grad_shard_name = _ParallelOptimizerConfig.GRADIENT_ACCUMULATION_SHARD
        if grad_shard_name in parallel_optimizer_config:
            Validator.check_bool(
                parallel_optimizer_config[grad_shard_name], grad_shard_name, grad_shard_name)
            self._context_handle.set_grad_accumulation_shard(
                parallel_optimizer_config[grad_shard_name])
        else:
            raise ValueError(f"The parallel_optimizer_config does not contain {grad_shard_name}, please check your "
                             f"parallel_optimizer_config")

    def get_grad_accumulation_shard(self):
        self.check_context_handle()
        return self._context_handle.get_grad_accumulation_shard()

    def set_sharding_propagation(self, sharding_propagation):
        """
        Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
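A short usage sketch of the setter above (hypothetical call site, assuming the usual module path mindspore.parallel._auto_parallel_context): a recognized key is validated as a bool and forwarded to the C++ context, while a dict without gradient_accumulation_shard raises ValueError.

from mindspore.parallel._auto_parallel_context import auto_parallel_context

ctx = auto_parallel_context()
ctx.set_parallel_optimizer_config({"gradient_accumulation_shard": False})
assert ctx.get_grad_accumulation_shard() is False

try:
    ctx.set_parallel_optimizer_config({"unknown_key": True})  # key is missing
except ValueError as err:
    print(err)  # the dict must contain "gradient_accumulation_shard"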
@@ -648,6 +688,7 @@ _set_auto_parallel_context_func_map = {
    "full_batch": auto_parallel_context().set_full_batch,
    "dataset_strategy": auto_parallel_context().set_dataset_strategy,
    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
    "parallel_optimizer_config": auto_parallel_context().set_parallel_optimizer_config,
    "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
    "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
@@ -802,5 +843,6 @@ def _reset_auto_parallel_context():
    - enable_parallel_optimizer: False
    - auto_parallel_search_mode: dynamic_programming
    - pipeline_stages: 0
    - gradient_accumulation_shard: True
    """
    auto_parallel_context().reset()
@@ -219,6 +219,11 @@ def _get_enable_parallel_optimizer():
    return auto_parallel_context().get_enable_parallel_optimizer()


def _get_grad_accumulation_shard():
    """Get whether gradient accumulation sharding is enabled."""
    return auto_parallel_context().get_grad_accumulation_shard()


def _device_number_check(parallel_mode, device_number):
    """
    Check device num.
@@ -18,6 +18,7 @@ import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore import context
from mindspore.ops import composite as C
from mindspore.ops import functional as F
import mindspore.ops as P
@@ -250,11 +251,18 @@ def test_transformer_model_auto_parallel_no_support():
                              mode=ParallelMode.AUTO_PARALLEL)


def test_pipeline_single_transformer():
def pipeline_single_transformer(grad_accumulation_shard=False):
    """
    Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
    Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
    Expectation: The compile passed
    """
    set_auto_parallel_context(device_num=32,
                              full_batch=True,
                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    context.set_auto_parallel_context(parallel_optimizer_config=
                                      {"gradient_accumulation_shard": grad_accumulation_shard})

    net = Transformer(batch_size=4 // pipeline_config.micro_batch_num,
                      src_seq_length=20,
@@ -286,6 +294,24 @@
    model.train(1, dataset, dataset_sink_mode=False)


def test_pipeline_transformer_gradient_shard_true():
    """
    Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
    Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard True
    Expectation: The compile passed
    """
    pipeline_single_transformer(grad_accumulation_shard=True)


def test_pipeline_transformer_gradient_shard_false():
    """
    Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
    Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
    Expectation: The compile passed
    """
    pipeline_single_transformer(grad_accumulation_shard=False)


def test_transformer_wrong_head():
    set_auto_parallel_context(device_num=32,
                              full_batch=True,