!4744 [AutoParallel]Support bert

Merge pull request !4744 from lichen/support_bert
This commit is contained in:
mindspore-ci-bot 2020-08-20 14:13:19 +08:00 committed by Gitee
commit 9ee144ea40
5 changed files with 272 additions and 1 deletions

View File

@ -610,6 +610,15 @@ Status MatMulBase::CheckForTensorSliceValid() const {
return SUCCESS; return SUCCESS;
} }
std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
batch_strategy.insert(batch_strategy.begin(), SizeToLong(dev_num));
Strategys strategy_v = {batch_strategy, batch_strategy};
return std::make_shared<Strategys>(strategy_v);
}
Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
if (InitForCostModel(strategy) == FAILED) { if (InitForCostModel(strategy) == FAILED) {
if (is_auto_parallel_) { if (is_auto_parallel_) {

View File

@ -91,6 +91,8 @@ class BatchMatMulInfo : public MatMul {
const PrimitiveAttrs &attrs) const PrimitiveAttrs &attrs)
: MatMul(name, inputs_shape, outputs_shape, attrs) {} : MatMul(name, inputs_shape, outputs_shape, attrs) {}
~BatchMatMulInfo() override = default; ~BatchMatMulInfo() override = default;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
}; };
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

View File

@ -162,6 +162,7 @@ constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLog
constexpr char MATMUL[] = "MatMul"; constexpr char MATMUL[] = "MatMul";
constexpr char GELU[] = "Gelu"; constexpr char GELU[] = "Gelu";
constexpr char TANH[] = "Tanh"; constexpr char TANH[] = "Tanh";
constexpr char SHAPE_OP[] = "Shape";
constexpr char SOFTMAX[] = "Softmax"; constexpr char SOFTMAX[] = "Softmax";
constexpr char LOG_SOFTMAX[] = "LogSoftmax"; constexpr char LOG_SOFTMAX[] = "LogSoftmax";
constexpr char ACTIVATION[] = "Activation"; constexpr char ACTIVATION[] = "Activation";

View File

@ -1673,6 +1673,41 @@ std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
return std::make_shared<TensorLayout>(input_tensor_layout); return std::make_shared<TensorLayout>(input_tensor_layout);
} }
RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const TensorLayout &loss_layout) {
MS_EXCEPTION_IF_NULL(node);
TensorRedistribution tensor_redistribution;
// create stand alone layout:TensorMap:[all -1],dev_matrix:[dev_num].
CheckGlobalDeviceManager();
int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
TensorLayout stand_alone_layout;
Shapes inputs_shape = GetNodeShape(node);
if (inputs_shape.empty()) {
MS_LOG(EXCEPTION) << "InferSensRedistribution failed cause inputs shape is empty.";
}
Shape input_shape_array = inputs_shape[0];
if (input_shape_array.empty()) {
MS_LOG(INFO) << "No need to redistribution for sens.";
return nullptr;
}
// TensorMap
TensorMap stand_alone_tensor_map_array(SizeToInt(input_shape_array.size()), -1);
// Dev_matrix
Shape dev_matrix_array = {dev_num};
if (stand_alone_layout.InitFromVector(dev_matrix_array, stand_alone_tensor_map_array, input_shape_array) == FAILED) {
MS_LOG(EXCEPTION) << "Create tensor layout for Sens failed.";
}
// Infer Redistribution op list for stand alone and loss layout.
RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
}
RedistributionOpListPtr sens_redistribution_list = tensor_redistribution.InferTensorRedistributionOperatorList();
MS_EXCEPTION_IF_NULL(sens_redistribution_list);
return sens_redistribution_list;
}
std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) { std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
if (node->isa<Parameter>()) { if (node->isa<Parameter>()) {
return CreateParameterLayout(node); return CreateParameterLayout(node);
@ -1897,7 +1932,18 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
return; return;
} }
MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; if (sens_tensor_node->isa<CNode>()) {
auto op_list_ptr = InferSensRedistribution(sens_tensor_node, loss_grad_layout);
if (op_list_ptr == nullptr) {
return;
}
auto sens_tensor_cnode = sens_tensor_node->cast<CNodePtr>();
auto func_graph = grad_sens_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
InsertRedistribution(op_list_ptr, grad_sens_node, func_graph, 1, sens_tensor_cnode);
return;
}
MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter or CNode, it is unsupported now.";
} }
// Use _GetTensorSlice operator to split the sens tensor // Use _GetTensorSlice operator to split the sens tensor
@ -2305,6 +2351,41 @@ std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const An
return root_forward_nodes; return root_forward_nodes;
} }
void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncGraphPtr &root) {
// shape op doesn't have params and attrs.
OperatorParams params;
OperatorAttrs attrs;
OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(SHAPE_OP, args);
InsertNode(op, node, 2, pre_node, root, "shape");
}
void HandleRootReshape(const std::vector<AnfNodePtr> &all_nodes) {
// If root graph has reshape op. Find the corresponding parameter.
// Reshape's shape is the shape of the parameter.
for (auto &node : all_nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0)) || cnode->in_forward_flag()) {
continue;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim->name() != RESHAPE) {
continue;
}
auto root = node->func_graph();
auto all_dfs_nodes = DeepLinkedGraphSearch(node);
for (auto r_iter = all_dfs_nodes.rbegin(); r_iter != all_dfs_nodes.rend(); ++r_iter) {
if ((*r_iter)->isa<Parameter>()) {
InsertShapeOp(cnode, *r_iter, root);
break;
}
}
}
}
void MarkForwardCNode(const FuncGraphPtr &root) { void MarkForwardCNode(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(root);
auto all_nodes = root->nodes(); auto all_nodes = root->nodes();
@ -2456,6 +2537,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// mark the forward cnodes, parallel only care these nodes // mark the forward cnodes, parallel only care these nodes
MarkForwardCNode(root); MarkForwardCNode(root);
HandleRootReshape(all_nodes);
if (FindCommunicationOp(all_nodes)) { if (FindCommunicationOp(all_nodes)) {
MS_LOG(EXCEPTION) << "The graph contain communication op"; MS_LOG(EXCEPTION) << "The graph contain communication op";

View File

@ -0,0 +1,177 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.nn as nn
from mindspore.train import Model, ParallelMode
from tests.dataset_mock import MindData
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
class TrainOneStepWithLossScaleCell(nn.Cell):
def __init__(self, network, optimizer, scale_update_cell=None):
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation('grad',
get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.grad_reducer = F.identity
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
@C.add_flags(has_effect=True)
def construct(self, x, sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(x)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(x, self.cast(scaling_sens, mstype.float32))
# apply grad reducer on grads
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, cond, scaling_sens)
return F.depend(ret, succ)
class DatasetLenet(MindData):
def __init__(self, predict, label, length=3):
super(DatasetLenet, self).__init__()
self.predict = predict
self.label = label
self.index = 0
self.length = length
def __iter__(self):
return self
def __next__(self):
if self.index >= self.length:
raise StopIteration
self.index += 1
return self.predict, self.label
def reset(self):
self.index = 0
class LoopLayer(nn.Cell):
def __init__(self):
super(LoopLayer, self).__init__()
self.matmul = P.MatMul()
self.relu = P.ReLU()
self.matmul_weight = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight")
def construct(self, x):
out = self.matmul(x, self.matmul_weight)
out = self.relu(out)
return out
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.exp = P.Exp()
self.mean = P.ReduceMean()
layers = []
for _ in range(3):
layer = LoopLayer()
layers.append(layer)
self.layers = nn.CellList(layers)
def construct(self, x):
out = self.exp(x)
for layer in self.layers:
layer_out = layer(out)
out = layer_out
out = self.mean(out, -1)
return out
def test_loss_scale():
context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8)
predict = Tensor(np.ones([64, 64]), dtype=ms.float32)
label = Tensor(np.ones([64,]), dtype=ms.int32)
dataset = DatasetLenet(predict, label)
net = Net()
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
net = TrainOneStepWithLossScaleCell(net, opt, update_cell)
model = Model(network=net)
model.train(2, dataset, dataset_sink_mode=False)