forked from mindspore-Ecosystem/mindspore

add cell comm recompute interface

This commit is contained in:
parent 6b3aedea1c
commit 5277b229be
@@ -33,8 +33,8 @@ namespace {
 constexpr auto kGradientsFlag = "Gradients";
 bool CanNotRecomputed(const CNodePtr &node) {
-  static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimAllGather, prim::kPrimDropoutGenMask,
-                                                                 prim::kPrimLoad, prim::kPrimTupleGetItem};
+  static std::unordered_set<PrimitivePtr> not_recomputed_op_list{prim::kPrimDropoutGenMask, prim::kPrimLoad,
+                                                                 prim::kPrimTupleGetItem};

   return std::any_of(not_recomputed_op_list.begin(), not_recomputed_op_list.end(),
                      [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
 }
@@ -134,6 +134,7 @@ constexpr char FUSION[] = "fusion";
 constexpr char DO_MIRROR[] = "do_mirror";
 constexpr char RECOMPUTE[] = "recompute";
+constexpr char RECOMPUTE_COMM_OP[] = "recompute_comm_op";
 constexpr char NOT_RECOMPUTE[] = "not_recompute";
 constexpr char NUM_SAMPLED[] = "num_sampled";
 constexpr char NUM_TRUE[] = "num_true";
 constexpr char SEED[] = "seed";
@@ -193,7 +194,7 @@ constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter";
 constexpr char FIELD_SIZE[] = "field_size";
 constexpr char OPTIMIZER_SUB_STRING[] = "optimizer";
 constexpr char DEVICE[] = "Device";
-constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather";
+constexpr char PARALLEL_OPTIMIZER_ALLGATHER[] = "parallel_optimizer_allgather_not_recompute";
 constexpr char CELLLIST_KEYWORD_PATTERN[] = "-CellList/(\\d+)-";

 constexpr char OUT_CHANNEL[] = "out_channel";
@@ -261,6 +261,9 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
   PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
   new_node_prim->set_instance_name(instance_name);
   new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
+  if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
+    new_node_prim->set_attr("recompute", MakeValue(false));
+  }
   new_node->set_scope(scope);
   node_input[0]->set_scope(scope);
   manager->SetEdge(node, SizeToLong(index), new_node);
@@ -290,6 +293,9 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons
   auto new_node_prim = GetValueNode<PrimitivePtr>(node_input[0]);
   new_node_prim->set_instance_name(instance_name);
   new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
+  if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
+    new_node_prim->set_attr("recompute", MakeValue(false));
+  }
   new_node->set_scope(scope);
   node_input[0]->set_scope(scope);
   manager->Replace(pre_node, new_node);
@@ -394,6 +400,18 @@ void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_p
   std::string op_name = (redistribution_oplist_ptr->first)[index].first;
   std::string instance_name_base = REDISTRIBUTION_OP;
   std::string instance_name = instance_name_base + "_" + CreateInstanceName(pre_node, index) + op_name;
+  auto prim_out = GetCNodePrimitive(node);
+  auto prim_in = GetCNodePrimitive(pre_node);
+  if (prim_out != nullptr && prim_in != nullptr) {
+    auto prim_out_attr = prim_out->attrs();
+    auto prim_in_attr = prim_in->attrs();
+    if (prim_out_attr.find(RECOMPUTE_COMM_OP) != prim_out_attr.end() &&
+        prim_in_attr.find(RECOMPUTE_COMM_OP) != prim_in_attr.end() &&
+        COMMUNICATION_OPS.find(op_name) != COMMUNICATION_OPS.end()) {
+      MS_LOG(INFO) << "The redistribution node would not be recomputed.";
+      instance_name = instance_name + "_" + NOT_RECOMPUTE;
+    }
+  }
   InsertNode(op, node, LongToSize(pos), target_node, func_graph, instance_name);
   if ((redistribution_oplist_ptr->second)[index].first) {
     target_node = node->input(LongToSize(pos));
@@ -443,12 +461,7 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &
 }

 std::string GetPrimName(const CNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (!IsValueNode<Primitive>(node->input(0))) {
-    MS_LOG(EXCEPTION) << "The node is not a primitive";
-  }
-  auto value_node = node->input(0)->cast<ValueNodePtr>();
-  auto prim = GetValueNode<PrimitivePtr>(value_node);
+  auto prim = GetCNodePrimitive(node);
   MS_EXCEPTION_IF_NULL(prim);
   return prim->name();
 }
@@ -881,6 +894,11 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
   PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
   PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
   SetUserAttrs(origin_prim->attrs(), prim);
+  if (origin_prim->attrs().find(RECOMPUTE_COMM_OP) != origin_prim->attrs().end() &&
+      COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()) {
+    MS_LOG(INFO) << "The redistribution node in reshape would not be recomputed.";
+    prim->set_attr("recompute", MakeValue(false));
+  }
   if (index == replace_op.size() - 1) {
     replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
     replace_node->set_primal_attrs(node->primal_attrs());
@@ -187,6 +187,8 @@ std::string MirrorOpName();

 CommInfo GetCommInfo();

+std::string GetPrimName(const CNodePtr &node);
+
 void ReorderForPipelineSplit(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager, int64_t pipeline_stages);
 }  // namespace parallel
 }  // namespace mindspore
@@ -21,6 +21,7 @@ from collections import OrderedDict

 import numpy

+from mindspore._checkparam import args_type_check
 from mindspore import log as logger
 from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
 from mindspore.common._decorator import deprecated
@@ -85,6 +86,7 @@ class Cell(Cell_):
         self._cells = OrderedDict()
         self._params_list = OrderedDict()
         self._tensor_list = OrderedDict()
+        self._primitives = OrderedDict()
         self.training = False
         self.requires_grad = False
         self.pynative = False
@@ -510,6 +512,7 @@ class Cell(Cell_):
         else:
             if isinstance(value, Primitive):
                 value.set_prim_instance_name(name)
+                self._primitives[name] = value
             object.__setattr__(self, name, value)
         if name not in Cell.IGNORE_LIST:
             self._attr_synced = False
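Aside: the `_primitives` bookkeeping added above is what the new communication-recompute call walks later; any Primitive assigned as a cell attribute is recorded under its attribute name. A minimal sketch of that behaviour, using a made-up TinyCell purely for illustration (not part of this change):

import mindspore.nn as nn
from mindspore.ops import operations as P

class TinyCell(nn.Cell):
    # Hypothetical cell: only here to show how the __setattr__ hook registers operators.
    def __init__(self):
        super(TinyCell, self).__init__()
        self.matmul = P.MatMul()  # expected to be recorded into self._primitives by the hook above

    def construct(self, x, y):
        return self.matmul(x, y)

cell = TinyCell()
# With the change applied, the operator should be discoverable by attribute name:
# 'matmul' in cell._primitives  ->  True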
@@ -1287,7 +1290,26 @@ class Cell(Cell_):
         elif not self._scope is None and self._scope.startswith(prefix):
             self._scope = self._scope[len(prefix):]

-    def recompute(self, mode=True, output_recompute=False):
+    def _mp_comm_recompute(self, mp_comm_recompute=True):
+        for _, value in self._primitives.items():
+            if value:
+                value.add_prim_attr("recompute_comm_op", mp_comm_recompute)
+        for cell in self.cells():
+            cell._mp_comm_recompute(mp_comm_recompute)
+
+    def _recompute(self, mode=True, output_recompute=False):
+        if context.get_context("mode") == context.PYNATIVE_MODE:
+            raise TypeError("Recompute is not supported in pynative mode currently.")
+        Validator.check_bool(mode)
+        Validator.check_bool(output_recompute)
+        self._set_recompute_scope(mode)
+        if mode and not output_recompute:
+            self.add_flags(output_no_recompute=True)
+        for cell in self.cells():
+            cell._recompute(mode, True)
+
+    @args_type_check(mode=bool, output_recompute=bool, mp_comm_recompute=bool)
+    def recompute(self, **kwargs):
         """
         Set the cell recomputed. All the primitive in the cell will be set recomputed. If a primitive
         set recomputed feeds into some backward nodes for computing gradient, rather than storing the
@@ -1304,16 +1326,25 @@ class Cell(Cell_):
             mode (bool): Specifies whether the cell is recomputed. Default: True.
             output_recompute (bool): Specifies whether the output of this cell is recomputed when
                 the mode is true. Note that when the mode is false, this arg is not working. Default: False.
+            mp_comm_recompute (bool): Specifies whether the model parallel communication operators in the
+                cell is recomputed in auto parallel or semi auto parallel mode. Default: True.
         """
-        if context.get_context("mode") == context.PYNATIVE_MODE:
-            raise TypeError("Recompute is not supported in pynative mode currently.")
-        Validator.check_bool(mode)
-        Validator.check_bool(output_recompute)
-        self._set_recompute_scope(mode)
-        if mode and not output_recompute:
-            self.add_flags(output_no_recompute=True)
-        for cell in self.cells():
-            cell.recompute(mode, True)
+        if not kwargs:
+            self._recompute()
+        if 'mode' in kwargs.keys() or 'output_recompute' in kwargs.keys():
+            mode = True
+            output_recompute = False
+            if 'mode' in kwargs.keys():
+                mode = kwargs['mode']
+            if 'output_recompute' in kwargs.keys():
+                output_recompute = kwargs['output_recompute']
+            self._recompute(mode, output_recompute)
+        if 'mp_comm_recompute' in kwargs.keys():
+            self._mp_comm_recompute(kwargs['mp_comm_recompute'])
+        for key, _ in kwargs.items():
+            if key not in ('mode', 'output_recompute', 'mp_comm_recompute'):
+                raise ValueError("Recompute keyword %s is not recognized!" % key)

     def infer_param_pipeline_stage(self):
         """
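A hedged usage sketch of the reworked interface (the cell below is a stand-in; only the keyword names come from this change): recompute() now accepts keyword arguments only, and the new mp_comm_recompute flag controls whether the model parallel communication operators inside the cell are recomputed:

import mindspore.nn as nn
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)  # recompute is rejected in PyNative mode
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)

block = nn.Dense(1024, 1024)  # stands in for any sub-cell of a larger network
block.recompute()                                    # recompute the whole cell, as before
block.recompute(mode=True, output_recompute=False)   # keywords replace the old positional args
block.recompute(mp_comm_recompute=False)             # new: leave activation recompute on, but do not
                                                     # recompute comm ops inserted by model parallelism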
@@ -38,7 +38,7 @@ def test_set_recompute_true():

 def test_set_recompute_false():
     net = Net()
-    net.pool.recompute(False)
+    net.pool.recompute(mode=False)
     assert net.pool.get_scope() is None

@@ -51,32 +51,32 @@ def test_set_recompute_true_twice():

 def test_set_recompute_false_twice():
     net = Net()
-    net.pool.recompute(False)
-    net.pool.recompute(False)
+    net.pool.recompute(mode=False)
+    net.pool.recompute(mode=False)
     assert net.pool.get_scope() is None


 def test_reset_recompute1():
     net = Net()
-    net.pool.recompute(True)
-    net.pool.recompute(False)
+    net.pool.recompute(mode=True)
+    net.pool.recompute(mode=False)
     assert net.pool.get_scope() == ""


 def test_reset_recompute2():
     net = Net()
-    net.pool.recompute(False)
-    net.pool.recompute(True)
+    net.pool.recompute(mode=False)
+    net.pool.recompute(mode=True)
     assert net.pool.get_scope() == recompute_prefix


 def test_set_scope_and_set_recompute_repeatedly():
     net = Net()
-    net.pool.recompute(True)
+    net.pool.recompute(mode=True)
     assert net.pool.get_scope() == recompute_prefix
-    net.pool.recompute(False)
+    net.pool.recompute(mode=False)
     assert net.pool.get_scope() == ""
-    net.pool.recompute(True)
+    net.pool.recompute(mode=True)
     assert net.pool.get_scope() == recompute_prefix
-    net.pool.recompute(False)
+    net.pool.recompute(mode=False)
     assert net.pool.get_scope() == ""
@@ -0,0 +1,87 @@
+# Copyright 2019 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore.nn as nn
+import mindspore as ms
+from mindspore import Tensor, context, Parameter
+from mindspore.common.api import _executor
+from mindspore.ops import operations as P
+from mindspore.common.initializer import initializer
+from mindspore.context import _Context
+from ....train_step_wrap import train_step_with_loss_warp
+
+
+class MatMulCell(nn.Cell):
+    def __init__(self):
+        super(MatMulCell, self).__init__()
+        self.reshape = P.Reshape()
+        self.matmul0 = P.MatMul()
+        self.weight = Parameter(initializer("ones", [128, 64], ms.float32), name="weight")
+        self.relu = P.ReLU().shard(((1, 8),))
+
+    def construct(self, x):
+        x = self.matmul0(x, self.weight)
+        x = self.reshape(x, (32, 128))
+        x = self.relu(x)
+        return x
+
+
+class DenseMutMulNet(nn.Cell):
+    def __init__(self):
+        super(DenseMutMulNet, self).__init__()
+        self.fc1 = nn.Dense(128, 768, activation='relu')
+        self.fc2 = nn.Dense(128, 768, activation='relu')
+        self.fc3 = nn.Dense(128, 768, activation='relu')
+        self.fc4 = nn.Dense(768, 768, activation='relu')
+        self.fc1.recompute()
+        self.fc2.recompute()
+        self.fc3.recompute()
+        self.fc1.matmul.shard(((1, 1), (1, 8)))
+        self.fc2.matmul.shard(((1, 1), (1, 8)))
+        self.fc3.matmul.shard(((1, 1), (1, 8)))
+        self.relu4 = nn.ReLU()
+        self.relu5 = nn.ReLU()
+        self.transpose = P.Transpose()
+        self.matmul1 = P.MatMul()
+        self.matmul2 = P.MatMul()
+        self.matmul_cell = MatMulCell()
+        self.matmul_cell.recompute()
+        self.fc1.recompute(mp_comm_recompute=False)
+        self.fc2.recompute(mp_comm_recompute=False)
+        self.fc3.recompute(mp_comm_recompute=False)
+        self.matmul_cell.recompute(mp_comm_recompute=False)
+
+    def construct(self, x):
+        x = self.matmul_cell(x)
+        q = self.fc1(x)
+        k = self.fc2(x)
+        v = self.fc3(x)
+        k = self.transpose(k, (1, 0))
+        c = self.relu4(self.matmul1(q, k))
+        s = self.relu5(self.matmul2(c, v))
+        s = self.fc4(s)
+        return s
+
+
+def test_dmnet_train_step():
+    context.reset_auto_parallel_context()
+    _Context().set_backend_policy("vm")
+    context.set_context(mode=context.GRAPH_MODE)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
+    input_ = Tensor(np.ones([64, 128]).astype(np.float32) * 0.01)
+    label = Tensor(np.zeros([32, 768]).astype(np.float32))
+    net = train_step_with_loss_warp(DenseMutMulNet())
+    net.set_auto_parallel()
+    net.set_train()
+    _executor.compile(net, input_, label)
+    _Context().set_backend_policy("ge")
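Under the hood, the recompute(mp_comm_recompute=False) calls in the test above reduce to _mp_comm_recompute, which stamps every registered primitive (recursively through sub-cells) with the recompute_comm_op attribute that the C++ redistribution pass then checks via RECOMPUTE_COMM_OP. A rough sketch of that mechanism, assuming the patch is applied (the cell and printout are illustrative only, not part of the change):

import mindspore.nn as nn

fc = nn.Dense(128, 768)  # any cell that owns primitives will do
fc.recompute(mp_comm_recompute=False)
# Each primitive owned by the cell should now carry the attribute read by the parallel pass:
print(fc.matmul.attrs.get("recompute_comm_op"))  # expected: False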