forked from mindspore-Ecosystem/mindspore

commit c12abe7a46
!12853 handle the fully split parameter for grad accumulation
From: @yangzhenzhang
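Editor's note (not part of the diff): when grad_accumulation_step > 1, per-step gradients are folded into the matching "accu_grad" parameters between micro-steps. For a parameter that is fully split across the devices of its stage this folding was evidently missing, so the commit inserts a _VirtualAdd between such a parameter and one of its forward users; the operator is a no-op in the forward pass and adds the accumulation buffer to the incoming gradient in the backward pass (see the bprop added in grad_comm_ops.py below). A minimal rendition of that backward rule, with plain NumPy standing in for MindSpore tensors, purely for illustration:

import numpy as np

def virtual_add_bprop(dout, grad_accu):
    # mirrors the new bprop: return (dout + grad_accu, zeros_like(grad_accu))
    return dout + grad_accu, np.zeros_like(grad_accu)

dout = np.array([0.1, 0.2])        # gradient of the current micro-step
grad_accu = np.array([0.5, 0.5])   # gradient accumulated so far
dx, d_accu = virtual_add_bprop(dout, grad_accu)
print(dx)       # [0.6 0.7]
print(d_accu)   # [0. 0.]
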
@@ -51,7 +51,7 @@
namespace mindspore {
namespace session {
static std::shared_ptr<std::map<ValuePtr, ParameterPtr>> python_paras;
static std::shared_ptr<std::map<ParamInfoPtr, ParameterPtr>> python_paras;
void ClearPythonParasMap() { python_paras = nullptr; }
namespace {
const int kSummaryGetItem = 2;
@@ -106,7 +106,7 @@ bool CheckIfNeedCreateOutputTensor(const AnfNodePtr &node) {
  return false;
}

ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
  if (node == nullptr) {
    return nullptr;
  }
@@ -114,7 +114,7 @@ ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
  if (parameter == nullptr || !parameter->has_default()) {
    return nullptr;
  }
  return parameter->default_param();
  return parameter->param_info();
}

tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
@@ -747,7 +747,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
  ParameterPtr new_parameter = nullptr;
  // if the python parameter already has a corresponding backend parameter, reuse it
  if (python_paras == nullptr) {
    python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
    python_paras = std::make_shared<std::map<ParamInfoPtr, ParameterPtr>>();
  }
  auto iter = python_paras->find(param_value);
  if (iter != python_paras->end()) {
@@ -1217,7 +1217,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
  auto param_value = GetParamDefaultValue(anf);
  ParameterPtr new_parameter = nullptr;
  if (python_paras == nullptr) {
    python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
    python_paras = std::make_shared<std::map<ParamInfoPtr, ParameterPtr>>();
  }
  auto iter = python_paras->find(param_value);
  if (iter != python_paras->end()) {

@@ -88,6 +88,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
                                        prim::kPrimMirrorMiniStep);
  mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
                                                  "mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
  virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual add", prim::kPrimVirtualAdd);
  check_bprop_eliminate_ =
    MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
  reset_defer_inline_ =

@@ -52,6 +52,7 @@ class OptimizeIRPassLib {
  SubstitutionPtr depend_value_elim_;
  SubstitutionPtr all_reduce_const_elim_;
  SubstitutionPtr mirror_mini_step_elim_;
  SubstitutionPtr virtual_add_elim_;
  SubstitutionPtr mini_step_allgather_replace_;

  // Env Item Eliminate

@@ -175,6 +175,25 @@ class MirrorMiniStepEliminater : public AnfVisitor {
  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimVirtualAdd, X, Z} -> X
class VirtualAddEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAdd) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
class MiniStepAllGatherPass : public AnfVisitor {
 public:
@@ -191,8 +210,15 @@ class MiniStepAllGatherPass : public AnfVisitor {
    MS_EXCEPTION_IF_NULL(prim);
    auto attrs = prim->attrs();
    std::string group = attrs[parallel::GROUP]->ToString();
    auto fusion = attrs[parallel::FUSION];
    parallel::Operator op = parallel::CreateAllGatherOp(group);
    std::vector<AnfNodePtr> node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER);
    auto prim_anf_node = node_input[0]->cast<ValueNodePtr>();
    prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    MS_EXCEPTION_IF_NULL(prim);
    attrs = prim->attrs();
    attrs[parallel::FUSION] = fusion;
    prim->SetAttrs(attrs);
    auto func_graph = inputs[1]->func_graph();
    CNodePtr new_node = func_graph->NewCNode(node_input);
    return new_node;

@@ -155,13 +155,23 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (func_graph->has_flag(AUTO_PARALLEL) &&
      (!func_graph->has_flag(TRAINING) ||
       (ParallelContext::GetInstance()->grad_accumulation_step() > 1 && !func_graph->has_flag(ACCUMULATION)))) {
  if (!func_graph->has_flag(AUTO_PARALLEL)) {
    return;
  }

  if (!func_graph->has_flag(TRAINING)) {
    init_param_shape_ = false;
    MS_LOG(INFO) << "In parallel evaluation or prediction, may need to restore the parameter shape";
    return;
  }

  if ((ParallelContext::GetInstance()->grad_accumulation_step() > 1) && !func_graph->has_flag(ACCUMULATION)) {
    init_param_shape_ = false;
    MS_LOG(INFO) << "In the second graph of parallel grad accumulation, need to restore the parameter shape";
  } else {
    param_shapes.clear();
    init_param_shape_ = true;
    MS_LOG(INFO) << "Init the parameter shape dict";
  }
}

@@ -171,6 +181,10 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &f
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(param_node);
  MS_EXCEPTION_IF_NULL(ptr);
  if (!func_graph->has_flag(AUTO_PARALLEL)) {
    return;
  }

  if (init_param_shape_) {
    return;
  }
@@ -182,7 +196,7 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &f
  Shape shape = iter->second;
  std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
  ptr->set_shape(base_shape);
  MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
  MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}

// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
@@ -192,6 +206,10 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(param_node);
  MS_EXCEPTION_IF_NULL(ptr);
  if (!func_graph->has_flag(AUTO_PARALLEL)) {
    return;
  }

  if (!init_param_shape_) {
    return;
  }

@@ -110,6 +110,8 @@ constexpr char STRIDES[] = "strides";
constexpr char GROUP[] = "group";
constexpr char FUSION[] = "fusion";
constexpr char DO_MIRROR[] = "do_mirror";
constexpr char RECOMPUTE[] = "recompute";
constexpr char RECOMPUTE_COMM_OP[] = "recompute_comm_op";
constexpr char NUM_SAMPLED[] = "num_sampled";
constexpr char NUM_TRUE[] = "num_true";
constexpr char SEED[] = "seed";

@@ -97,6 +97,27 @@ void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool acc
  prim->SetAttrs(attrs);
}

void SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> &new_node_input, const CNodePtr &node) {
  if (new_node_input.empty()) {
    return;
  }

  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  auto attrs = prim->attrs();

  auto anf_node = node->input(0)->cast<ValueNodePtr>();
  auto prim_node = GetValueNode<PrimitivePtr>(anf_node);
  MS_EXCEPTION_IF_NULL(prim_node);
  auto node_attrs = prim_node->attrs();
  if (node_attrs.find(RECOMPUTE_COMM_OP) != node_attrs.end() && !GetValue<bool>(node_attrs[RECOMPUTE_COMM_OP])) {
    attrs[RECOMPUTE] = MakeValue<bool>(false);
    prim->SetAttrs(attrs);
    MS_LOG(INFO) << "Do not recompute the forward communication operator of " << prim_node->ToString();
  }
}

std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
  MS_EXCEPTION_IF_NULL(node);
  OperatorArgs arg_forward = op.second;
@@ -353,6 +374,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
    std::string instance_name_base = FORWARD_OP;
    std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
    std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
    SetAllReduceRecomputeFlag(forward_input, node_to_insert);
    CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to create anfnode
    MS_EXCEPTION_IF_NULL(forward_node);
    ScopePtr scope = node->scope();
@@ -1165,7 +1187,14 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons

    // not a RefKey
    if (!param_node_pair.second) {
      auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph);
      int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
      std::string mirror_op_name;
      if (grad_accumulation_step > 1) {
        mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
      } else {
        mirror_op_name = MIRROR_OPERATOR;
      }
      auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph);
      // if there is already a MirrorOp in the same graph, use the MirrorOp CNode as an input instead
      if (next_cnode.first) {
        MS_EXCEPTION_IF_NULL(next_cnode.second);
@@ -1743,6 +1772,10 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
    if (found_be_cloned_parameter) {
      // set the shape and tensor layout for cloned parameter
      std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
      if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
        MS_LOG(WARNING) << "The parameter " << param_name << " has no tensor layout, skip it";
        continue;
      }
      cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
      MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
      MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
@@ -3298,6 +3331,97 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
  }
}

static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
  auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
  if (tensor_layout == nullptr) {
    return false;
  }

  auto dev_mat_shape = tensor_layout->device_arrangement().array();
  auto tensor_map = tensor_layout->tensor_map().array();
  int64_t rank = g_device_manager->global_rank();
  RankList rank_list = g_device_manager->GetDeviceListInThisStage();
  DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
    return false;
  }

  if (group_devices.size() == 1) {
    MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
    return true;
  }
  return false;
}

static AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name) {
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    if (param_ptr->name() == name) {
      continue;
    }
    if (param_ptr->name().find(name) != std::string::npos && param_ptr->name().find("accu_grad") != std::string::npos) {
      return parameter;
    }
  }
  return nullptr;
}

static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
                                          const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
  auto cnode = node_user.first->cast<CNodePtr>();
  auto prim = GetCNodePrimitive(cnode);
  if (prim == nullptr) {
    MS_LOG(WARNING) << cnode->DebugString() << " can not insert fully split param grad accumulation node";
    return;
  }
  OperatorAttrs attrs;
  auto py_instance = CreatOpInstance(attrs, "_VirtualAdd", "grad_accu");
  auto value_node = NewValueNode(py_instance);
  std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(node_user.second), accu_parameter};
  auto graph = cnode->func_graph();
  auto virtual_node = graph->NewCNode(virtual_node_input);
  manager->SetEdge(cnode, node_user.second, virtual_node);
}

static void HandleFullySplitParameters(const FuncGraphPtr &root) {
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  if ((grad_accumulation_step <= 1) || root->has_flag(ACCUMULATION)) {
    return;
  }

  auto parameters = root->parameters();
  auto node_users_map = root->manager()->node_users();
  for (auto &parameter : parameters) {
    auto param_ptr = parameter->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);

    if (!IsFullySplitParameter(param_ptr)) {
      continue;
    }

    auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
    if (!accu_parameter) {
      continue;  // some parameters do not need to be handled, such as the parameter itself or lr
    }

    auto node_users = node_users_map[parameter];
    for (auto &user : node_users) {
      auto node = user.first;
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (!cnode->in_forward_flag()) {
        continue;
      }
      InsertFullySplitParamGradAccu(user, root->manager(), accu_parameter);
      MS_LOG(INFO) << "Insert fully split assign add node for " << param_ptr->name();
      break;  // only need to insert once, even if the parameter has many users
    }
  }
}

bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
@@ -3390,6 +3514,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
    MS_LOG(EXCEPTION) << "Save group info failed";
  }

  // handle fully split parameters in grad accumulation, not including optimizer-sharding's parameters
  HandleFullySplitParameters(root);

  DumpGraph(root, std::string(STEP_PARALLEL_END));

  // step parallel only run once
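Editor's note (not part of the diff): the fully-split check above boils down to a device-group computation. IsFullySplitParameter asks the device matrix for the devices that hold the same slice of the parameter as the current rank and reports "fully split" when that group contains only this rank. A rough, self-contained restatement in Python; the right-to-left tensor-map indexing convention and -1 meaning "replicated along this tensor axis" are assumptions about MindSpore's layout encoding, not taken from this diff:

def mirror_group_size(dev_mat, tensor_map):
    """Size of the device group that shares an identical parameter slice:
    the device-matrix axes that the tensor map does not use."""
    used = {len(dev_mat) - 1 - m for m in tensor_map if m >= 0}
    size = 1
    for axis, dim in enumerate(dev_mat):
        if axis not in used:
            size *= dim
    return size

# device_arrangement [2, 4] with tensor_map [1, -1] (as in the test layout below):
# 4 devices share each slice, so the parameter is not fully split.
print(mirror_group_size([2, 4], [1, -1]))   # 4
# hypothetical tensor_map [1, 0]: every device owns a distinct slice -> fully split
print(mirror_group_size([2, 4], [1, 0]))    # 1
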
@@ -159,6 +159,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
    irpass.switch_layer_defer_inline_,
    irpass.replace_applicator_,
    irpass.mirror_mini_step_elim_,
    irpass.virtual_add_elim_,
    irpass.row_tensor_add_zeros_like_,
    irpass.mini_step_allgather_replace_,
  });

@@ -307,6 +307,7 @@ inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOper
inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("_MiniStepAllGather");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualAdd = std::make_shared<Primitive>("_VirtualAdd");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");

@@ -22,7 +22,7 @@ from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                   _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive

@@ -108,6 +108,14 @@ def get_bprop_receive(self):
    return bprop


@bprop_getters.register(_VirtualAdd)
def get_bprop_virtual_add(self):
    """Generate bprop for _VirtualAdd."""
    def bprop(x, grad_accu, out, dout):
        return (dout + grad_accu, zeros_like(grad_accu))
    return bprop


@bprop_getters.register(Broadcast)
def get_bprop_broad_cast(self):
    """Generate bprop for Broadcast."""
@@ -168,13 +176,13 @@ def get_bprop_mini_step_all_gather(self):
    def bprop(x, z, out, dout):
        if do_mirror:
            if mean_flag:
                tmp = z + dout
                grad = all_reduce(tmp)
                z = F.depend(z, F.assign_add(z, dout))
                grad = all_reduce(z)
                dx = split(grad)[rank]
                dx = F.tensor_mul(dx, scale)
            else:
                tmp = z + dout
                grad = all_reduce(tmp)
                z = F.depend(z, F.assign_add(z, dout))
                grad = all_reduce(z)
                dx = split(grad)[rank]
        else:
            dx = dout
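Editor's note (not part of the diff): the change in get_bprop_mini_step_all_gather just above swaps a read-only temporary (tmp = z + dout) for an in-place update of the accumulation input (F.assign_add(z, dout) wrapped in F.depend), so the buffer handed to all_reduce already carries the running sum. A toy single-process rendition of the new flow, with plain floats standing in for tensors and the all-reduce elided, purely for illustration:

def mini_step_backward(z, dout):
    """One micro-step of the new bprop path: fold dout into the
    accumulation buffer z, then reduce the updated buffer."""
    z = z + dout          # stands in for F.assign_add(z, dout)
    grad = z              # stands in for all_reduce(z) on one process
    return z, grad

z = 0.0
for dout in (0.1, 0.2, 0.3):
    z, grad = mini_step_backward(z, dout)
print(round(z, 2), round(grad, 2))   # 0.6 0.6 -> the buffer carries the running sum
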
@@ -326,7 +334,6 @@ def get_bprop_mirror_mini_step_operator(self):
    mean_flag = self.mean_flag

    all_reduce = AllReduce(group=group)
    all_gather = AllGather(group=group)
    mul = P.Mul()
    cast = P.Cast()

@@ -345,8 +352,8 @@ def get_bprop_mirror_mini_step_operator(self):
        if mean_flag:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                if do_mirror:
                    tmp = z + dout
                    real_grad = all_reduce(tmp)
                    z = F.depend(z, F.assign_add(z, dout))
                    real_grad = all_reduce(z)
                    dx = real_grad
                else:
                    dx = dout
@@ -354,32 +361,17 @@ def get_bprop_mirror_mini_step_operator(self):
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
            else:
                if do_mirror:
                    indices = all_gather(dout.indices)
                    grad = all_gather(dout.values)
                else:
                    indices = dout.indices
                    grad = dout.values
                float_one = F.scalar_cast(1.0, F.dtype(grad))
                num = F.scalar_cast(dev_num, F.dtype(grad))
                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
                dx = RowTensor(indices, grad, dout.dense_shape)
                dx = zeros_like(x)  # The grad accumulation does not support row tensor now
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                if do_mirror:
                    tmp = z + dout
                    real_grad = all_reduce(tmp)
                    z = F.depend(z, F.assign_add(z, dout))
                    real_grad = all_reduce(z)
                    dx = real_grad
                else:
                    dx = dout
            else:
                if do_mirror:
                    indices = all_gather(dout.indices)
                    grad = all_gather(dout.values)
                else:
                    indices = dout.indices
                    grad = dout.values
                dx = RowTensor(indices, grad, dout.dense_shape)
                dx = zeros_like(x)  # The grad accumulation does not support row tensor now

        return (dx, zeros_like(z))
    return bprop
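Editor's note (not part of the diff): the Python-level setup that exercises these bprops appears in the test file at the end of this commit; grad accumulation is enabled through the auto-parallel context, and the accumulation and update phases are distinguished by the accu flag. A condensed sketch, assuming a MindSpore build that contains this patch; Net stands for whatever sharded network is being trained:

from mindspore import context

# Accumulate gradients over 4 micro-steps in semi-auto-parallel mode,
# mirroring the new test cases at the bottom of this commit.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  device_num=8, global_rank=0,
                                  grad_accumulation_step=4)

# The accumulation graph is tagged with the `accu` flag; the parallel passes
# added above key off the corresponding ACCUMULATION flag on the func graph.
# net = Net(...).add_flags_recursive(accu=True)   # accumulation phase
# net = Net(...).add_flags_recursive(accu=False)  # update phase
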
@@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                        Unique, GatherD, Identity, Range)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                       _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice,
                       _VirtualDiv, _GetTensorSlice, _VirtualAdd,
                       _HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                        TensorSummary, HistogramSummary, Print, Assert)

@@ -653,6 +653,19 @@ class _VirtualDiv(PrimitiveWithInfer):
    virtual_div = _VirtualDiv()


class _VirtualAdd(PrimitiveWithInfer):
    """Auto parallel virtual operator. Do nothing in forward, do Add in backward."""
    @prim_attr_register
    def __init__(self):
        """init"""

    def infer_shape(self, x_shape, y_shape):
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype):
        return x_dtype


class _VirtualDataset(PrimitiveWithInfer):
    """
    Auto parallel virtual dataset operator.

@@ -25,6 +25,7 @@ from mindspore.common.initializer import TruncatedNormal, initializer, Normal
from mindspore.ops import operations as P
from mindspore.ops import functional as F


class LayerNorm(nn.Cell):
    """
    Layer Normalization

@@ -47,6 +47,7 @@ def test_get_parameter_layout():

    net = Net(strategy1, strategy2, weight)
    net.set_auto_parallel()
    net.set_train()
    exe = me._executor
    exe.compile(net, x, phase='train', auto_parallel_mode=True)
    x_layout = ([2, 4], [1, -1], [16, 32], 0, True, '')  # device_arrangement = [2, 4], tensor_map = [1, -1]

@@ -1,307 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, Parameter
from mindspore.train import Model
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from mindspore.context import ParallelMode
from mindspore.nn import DistributedGradReducer, DynamicLossScaleUpdateCell, Cell, Momentum, Norm
from mindspore.parallel._utils import _get_device_num
from tests.dataset_mock import MindData


class Dataset(MindData):
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor")
def _get_square_sum(grad):
    norm = P.ReduceSum(False)(F.square(grad), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad


class GlobalNorm(Cell):
    """
    Calculate the global norm value of given tensors
    """
    def __init__(self):
        super(GlobalNorm, self).__init__()
        self.norm = Norm()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads)
        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
        return global_norms


class ClipByGlobalNorm(Cell):
    """
    Clip grads by global norm
    """
    def __init__(self, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.global_norm = GlobalNorm()
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm = self.global_norm(grads)
        cond = P.GreaterEqual()(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads


cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))


zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")


@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))


grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)


class TrainAccumulateStepsWithLossScaleCell(Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network, after which the construct
    function can be called to create the backward graph. To mimic a higher batch size, gradients are
    accumulated N times before the weight update.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
                                  batch_size * accumulation_steps. Default: 1.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
        self.accu = False
        self.is_accu_step = Tensor(np.array([self.accu]))
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = context.get_auto_parallel_context("grad_accumulation_step")
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
        self.reducer_flag = False
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.degree = 1
        self.grad_reducer = F.identity
        if self.reducer_flag:
            self.degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    @C.add_flags(has_effect=True)
    def construct(self, x, b, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(x, b)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))

        if self.is_accu_step and self.accumulation_steps > 1:
            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
            loss = F.depend(loss, accu_succ)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        overflow = self.less_equal(self.base, flag_sum)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(self.is_accu_step, accu_overflow, self.zero)

        if self.is_accu_step:
            succ = False
        else:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            grads = ClipByGlobalNorm()(grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
            F.control_depend(grads, accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (loss, overflow, scaling_sens)
        return F.depend(ret, succ)


class Net(Cell):
    def __init__(self, weight, strategy=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy)
        self.weight = Parameter(weight, "w1")
        self.relu = P.ReLU()
        self.reduce_sum = P.ReduceSum(keep_dims=True)

    def construct(self, x, b):
        out = self.mul(x, self.weight)
        out = self.relu(out)
        out = self.reduce_sum(out)
        return out


_x = Tensor(np.ones([2]), dtype=ms.float32)
_b = Tensor(np.ones([16]), dtype=ms.float32)
_w1 = Tensor(np.ones([16]), dtype=ms.float32)


def compile_net(net):
    context.set_context(enable_sparse=False)
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    dataset = Dataset(_x, _b)
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
    net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell)
    model = Model(net_wrap)
    model.train(epoch_size, dataset, dataset_sink_mode=False)
    context.reset_auto_parallel_context()


def test_grad_accumulation_accu():
    grad_accumulation_step = 4
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                      grad_accumulation_step=grad_accumulation_step)
    strategy = ((2,), (2,))
    net = Net(_w1, strategy).add_flags_recursive(accu=True)
    compile_net(net)


def test_grad_accu_and_opt_shard_accu():
    grad_accumulation_step = 4
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                      grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
    strategy = ((2,), (2,))
    net = Net(_w1, strategy).add_flags_recursive(accu=True)
    compile_net(net)


def test_grad_accumulation_not_accu():
    grad_accumulation_step = 4
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                      grad_accumulation_step=grad_accumulation_step)
    strategy = ((2,), (2,))
    net = Net(_w1, strategy).add_flags_recursive(accu=False)
    compile_net(net)


def test_grad_accu_and_opt_shard_not_accu():
    grad_accumulation_step = 4
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                      grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
    strategy = ((2,), (2,))
    net = Net(_w1, strategy).add_flags_recursive(accu=False)
    compile_net(net)