add cell comm recompute interface

yao_yf 2021-08-12 21:07:55 +08:00
parent 6b3aedea1c
commit 5277b229be
7 changed files with 169 additions and 30 deletions
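At a glance, this commit extends Cell.recompute() with an mp_comm_recompute keyword so a cell can keep its model-parallel communication operators out of recomputation. A minimal usage sketch; the cell and the parallel context below are illustrative, only the recompute(...) calls reflect the interface added here:

import mindspore.nn as nn
from mindspore import context

# Illustrative setup, not part of the commit.
context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)

dense = nn.Dense(128, 768)
dense.recompute()                          # recompute the cell, as before
dense.recompute(mp_comm_recompute=False)   # keep its model-parallel comm ops out of recomputation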

View File

@ -33,8 +33,8 @@ namespace {
constexpr auto kGradientsFlag = "Gradients";
bool CanNotRecomputed(const CNodePtr &node) {
static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask,
prim::kPrimLoad, prim::kPrimTupleGetItem};
static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimDropoutGenMask, prim::kPrimLoad,
prim::kPrimTupleGetItem};
return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });

View File

@ -134,6 +134,7 @@ constexpr char FUSION[] = "fusion";
constexpr char DO_MIRROR[] = "do_mirror";
constexpr char RECOMPUTE[] = "recompute";
constexpr char RECOMPUTE_COMM_OP[] = "recompute_comm_op";
constexpr char NOT_RECOMPUTE[] = "not_recompute";
constexpr char NUM_SAMPLED[] = "num_sampled";
constexpr char NUM_TRUE[] = "num_true";
constexpr char SEED[] = "seed";
@ -193,7 +194,7 @@ constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
constexpr char FIELD_SIZE[] = "field_size";
constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
constexpr char DEVICE[] = "Device";
constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather";
constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather_not_recompute";
constexpr char CELLLIST_KEYWORD_PATTERN[] = "-CellList/(\\d+)-";
constexpr char OUT_CHANNEL[] = "out_channel";

View File

@ -261,6 +261,9 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
new_node_prim->set_instance_name(instance_name);
new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
new_node_prim->set_attr("recompute", MakeValue(false));
}
new_node->set_scope(scope);
node_input[0]->set_scope(scope);
manager->SetEdge(node, SizeToLong(index), new_node);
@ -290,6 +293,9 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons
auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
new_node_prim->set_instance_name(instance_name);
new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
new_node_prim->set_attr("recompute", MakeValue(false));
}
new_node->set_scope(scope);
node_input[0]->set_scope(scope);
manager->Replace(pre_node, new_node);
@ -394,6 +400,18 @@ void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_p
std::string op_name = (redistribution_oplist_ptr->first)[index].first;
std::string instance_name_base = REDISTRIBUTION_OP;
std::string instance_name = instance_name_base + "_" + CreateInstanceName(pre_node, index) + op_name;
auto prim_out = GetCNodePrimitive(node);
auto prim_in = GetCNodePrimitive(pre_node);
if (prim_out != nullptr && prim_in != nullptr) {
auto prim_out_attr = prim_out->attrs();
auto prim_in_attr = prim_in->attrs();
if (prim_out_attr.find(RECOMPUTE_COMM_OP) != prim_out_attr.end() &&
prim_in_attr.find(RECOMPUTE_COMM_OP) != prim_in_attr.end() &&
COMMUNICATION_OPS.find(op_name) != COMMUNICATION_OPS.end()) {
MS_LOG(INFO) << "The redistribution node would not be recomputed.";
instance_name = instance_name + "_" + NOT_RECOMPUTE;
}
}
InsertNode(op, node, LongToSize(pos), target_node, func_graph, instance_name);
if ((redistribution_oplist_ptr->second)[index].first) {
target_node = node->input(LongToSize(pos));
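The RECOMPUTE_COMM_OP ("recompute_comm_op") attribute this check looks for is set from the Python side by the new Cell._mp_comm_recompute helper later in this commit, which tags every primitive owned by the cell. A rough sketch of that tagging, with an illustrative primitive:

from mindspore.ops import operations as P

# Illustrative primitive; the attribute name matches the Python change in this
# commit. InsertRedistribution checks for its presence on both the producer and
# consumer primitives (and that the op is a communication op) before deciding
# how the inserted redistribution op is recomputed.
matmul = P.MatMul()
matmul.add_prim_attr("recompute_comm_op", False)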
@ -443,12 +461,7 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &
}
std::string GetPrimName(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!IsValueNode<Primitive>(node->input(0))) {
MS_LOG(EXCEPTION) << "The node is not a primitive";
}
auto value_node = node->input(0)->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(value_node);
auto prim = GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(prim);
return prim->name();
}
@ -881,6 +894,11 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
SetUserAttrs(origin_prim->attrs(), prim);
if (origin_prim->attrs().find(RECOMPUTE_COMM_OP) != origin_prim->attrs().end() &&
COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()) {
MS_LOG(INFO) << "The redistribution node in reshape would not be recomputed.";
prim->set_attr("recompute", MakeValue(false));
}
if (index == replace_op.size() - 1) {
replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
replace_node->set_primal_attrs(node->primal_attrs());

View File

@ -187,6 +187,8 @@ std::string MirrorOpName();
CommInfo GetCommInfo();
std::string GetPrimName(const CNodePtr &node);
void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages);
} // namespace parallel
} // namespace mindspore

View File

@ -21,6 +21,7 @@ from collections import OrderedDict
import numpy
from mindspore._checkparam import args_type_check
from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.common._decorator import deprecated
@ -85,6 +86,7 @@ class Cell(Cell_):
self._cells = OrderedDict()
self._params_list = OrderedDict()
self._tensor_list = OrderedDict()
self._primitives = OrderedDict()
self.training = False
self.requires_grad = False
self.pynative = False
@ -510,6 +512,7 @@ class Cell(Cell_):
else:
if isinstance(value, Primitive):
value.set_prim_instance_name(name)
self._primitives[name] = value
object.__setattr__(self, name, value)
if name not in Cell.IGNORE_LIST:
self._attr_synced = False
@ -1287,7 +1290,26 @@ class Cell(Cell_):
elif not self._scope is None and self._scope.startswith(prefix):
self._scope = self._scope[len(prefix):]
def recompute(self, mode=True, output_recompute=False):
def _mp_comm_recompute(self, mp_comm_recompute=True):
for _, value in self._primitives.items():
if value:
value.add_prim_attr("recompute_comm_op", mp_comm_recompute)
for cell in self.cells():
cell._mp_comm_recompute(mp_comm_recompute)
def _recompute(self, mode=True, output_recompute=False):
if context.get_context("mode") == context.PYNATIVE_MODE:
raise TypeError("Recompute is not supported in pynative mode currently.")
Validator.check_bool(mode)
Validator.check_bool(output_recompute)
self._set_recompute_scope(mode)
if mode and not output_recompute:
self.add_flags(output_no_recompute=True)
for cell in self.cells():
cell._recompute(mode, True)
@args_type_check(mode=bool, output_recompute=bool, mp_comm_recompute=bool)
def recompute(self, **kwargs):
"""
Set the cell recomputed. All the primitives in the cell will be set recomputed. If a primitive
set recomputed feeds into some backward nodes for computing gradient, rather than storing the
@ -1304,16 +1326,25 @@ class Cell(Cell_):
mode (bool): Specifies whether the cell is recomputed. Default: True.
output_recompute (bool): Specifies whether the output of this cell is recomputed when
the mode is true. Note that when the mode is false, this argument has no effect. Default: False.
mp_comm_recompute (bool): Specifies whether the model parallel communication operators in the
cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
"""
if context.get_context("mode") == context.PYNATIVE_MODE:
raise TypeError("Recompute is not supported in pynative mode currently.")
Validator.check_bool(mode)
Validator.check_bool(output_recompute)
self._set_recompute_scope(mode)
if mode and not output_recompute:
self.add_flags(output_no_recompute=True)
for cell in self.cells():
cell.recompute(mode, True)
if not kwargs:
self._recompute()
if 'mode' in kwargs.keys() or 'output_recompute' in kwargs.keys():
mode = True
output_recompute = False
if 'mode' in kwargs.keys():
mode = kwargs['mode']
if 'output_recompute' in kwargs.keys():
output_recompute = kwargs['output_recompute']
self._recompute(mode, output_recompute)
if 'mp_comm_recompute' in kwargs.keys():
self._mp_comm_recompute(kwargs['mp_comm_recompute'])
for key, _ in kwargs.items():
if key not in ('mode', 'output_recompute', 'mp_comm_recompute'):
raise ValueError("Recompute keyword %s is not recognized!" % key)
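Taken together, the reworked recompute() is a thin keyword dispatcher over the two private helpers. A hedged sketch of the call patterns it accepts (the cell is illustrative):

import mindspore.nn as nn
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)   # recompute() rejects PyNative mode
cell = nn.Dense(128, 768)                      # illustrative cell

cell.recompute()                               # no kwargs: same behavior as before
cell.recompute(mode=False, output_recompute=False)   # the old positional args, now passed by keyword
cell.recompute(mp_comm_recompute=False)        # only tags comm ops; the recompute scope is untouched
# cell.recompute(foo=True)                     # raises ValueError: "Recompute keyword foo is not recognized!"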
def infer_param_pipeline_stage(self):
"""

View File

@ -38,7 +38,7 @@ def test_set_recompute_true():
def test_set_recompute_false():
net = Net()
net.pool.recompute(False)
net.pool.recompute(mode=False)
assert net.pool.get_scope() is None
@ -51,32 +51,32 @@ def test_set_recompute_true_twice():
def test_set_recompute_false_twice():
net = Net()
net.pool.recompute(False)
net.pool.recompute(False)
net.pool.recompute(mode=False)
net.pool.recompute(mode=False)
assert net.pool.get_scope() is None
def test_reset_recompute1():
net = Net()
net.pool.recompute(True)
net.pool.recompute(False)
net.pool.recompute(mode=True)
net.pool.recompute(mode=False)
assert net.pool.get_scope() == ""
def test_reset_recompute2():
net = Net()
net.pool.recompute(False)
net.pool.recompute(True)
net.pool.recompute(mode=False)
net.pool.recompute(mode=True)
assert net.pool.get_scope() == recompute_prefix
def test_set_scope_and_set_recompute_repeatedly():
net = Net()
net.pool.recompute(True)
net.pool.recompute(mode=True)
assert net.pool.get_scope() == recompute_prefix
net.pool.recompute(False)
net.pool.recompute(mode=False)
assert net.pool.get_scope() == ""
net.pool.recompute(True)
net.pool.recompute(mode=True)
assert net.pool.get_scope() == recompute_prefix
net.pool.recompute(False)
net.pool.recompute(mode=False)
assert net.pool.get_scope() == ""

View File

@ -0,0 +1,87 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore import Tensor, context, Parameter
from mindspore.common.api import _executor
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore.context import _Context
from ....train_step_wrap import train_step_with_loss_warp
class MatMulCell(nn.Cell):
    def __init__(self):
        super(MatMulCell, self).__init__()
        self.reshape = P.Reshape()
        self.matmul0 = P.MatMul()
        self.weight = Parameter(initializer("ones", [128, 64], ms.float32), name="weight")
        self.relu = P.ReLU().shard(((1, 8),))

    def construct(self, x):
        x = self.matmul0(x, self.weight)
        x = self.reshape(x, (32, 128))
        x = self.relu(x)
        return x


class DenseMutMulNet(nn.Cell):
    def __init__(self):
        super(DenseMutMulNet, self).__init__()
        self.fc1 = nn.Dense(128, 768, activation='relu')
        self.fc2 = nn.Dense(128, 768, activation='relu')
        self.fc3 = nn.Dense(128, 768, activation='relu')
        self.fc4 = nn.Dense(768, 768, activation='relu')
        self.fc1.recompute()
        self.fc2.recompute()
        self.fc3.recompute()
        self.fc1.matmul.shard(((1, 1), (1, 8)))
        self.fc2.matmul.shard(((1, 1), (1, 8)))
        self.fc3.matmul.shard(((1, 1), (1, 8)))
        self.relu4 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.transpose = P.Transpose()
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()
        self.matmul_cell = MatMulCell()
        self.matmul_cell.recompute()
        self.fc1.recompute(mp_comm_recompute=False)
        self.fc2.recompute(mp_comm_recompute=False)
        self.fc3.recompute(mp_comm_recompute=False)
        self.matmul_cell.recompute(mp_comm_recompute=False)

    def construct(self, x):
        x = self.matmul_cell(x)
        q = self.fc1(x)
        k = self.fc2(x)
        v = self.fc3(x)
        k = self.transpose(k, (1, 0))
        c = self.relu4(self.matmul1(q, k))
        s = self.relu5(self.matmul2(c, v))
        s = self.fc4(s)
        return s


def test_dmnet_train_step():
    context.reset_auto_parallel_context()
    _Context().set_backend_policy("vm")
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
    input_ = Tensor(np.ones([64, 128]).astype(np.float32) * 0.01)
    label = Tensor(np.zeros([32, 768]).astype(np.float32))
    net = train_step_with_loss_warp(DenseMutMulNet())
    net.set_auto_parallel()
    net.set_train()
    _executor.compile(net, input_, label)
    _Context().set_backend_policy("ge")