!9122 [AutoParallel] add LayerNorm Dropout and SegmentSum/Max/Min for GPT

From: @ch-l
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
mindspore-ci-bot 2020-11-30 17:33:26 +08:00 committed by Gitee
commit a09f1e30b6
5 changed files with 35 additions and 10 deletions

View File

@@ -309,7 +309,14 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
}
std::vector<int64_t> axis_list;
-auto iter = ops[iter_ops]->attrs().find(AXIS);
+string axis_name = AXIS;
+int64_t default_axis = -1;
+if (ops[iter_ops]->type() == LAYER_NORM) {
+axis_name = "begin_norm_axis";
+default_axis = 1;
+}
+auto iter = ops[iter_ops]->attrs().find(axis_name);
if (iter != ops[iter_ops]->attrs().end()) {
MS_EXCEPTION_IF_NULL(iter->second);
if (iter->second->isa<Int64Imm>()) {
@@ -326,8 +333,9 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t or tuple int64_t.";
}
} else {
-axis_list.push_back(-1);
+axis_list.push_back(default_axis);
}
for (auto &axis : axis_list) {
if (axis < 0) {
int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
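
Note: a minimal standalone sketch of the axis handling these two hunks implement, for reference. Softmax keeps the old default of -1, LayerNorm reads "begin_norm_axis" and defaults to 1, and negative axes are then wrapped against the input rank (the wrap-around below assumes the usual axis += rank rule suggested by the if (axis < 0) branch above; this helper is illustrative, not part of the patch).

    #include <cstdint>
    #include <string>

    int64_t ResolveAxis(const std::string &op_type, bool has_axis_attr, int64_t attr_axis, int64_t input_rank) {
      // Default differs per op: LayerNorm normalizes from begin_norm_axis = 1, Softmax from the last axis.
      int64_t axis = has_axis_attr ? attr_axis : (op_type == "LayerNorm" ? 1 : -1);
      if (axis < 0) {
        axis += input_rank;  // e.g. axis -1 on a rank-3 input resolves to 2
      }
      return axis;
    }
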
@@ -481,10 +489,10 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
} else if (type == ONEHOT) {
return PrepareOneHot(graph, ops, iter_graph, iter_ops);
-} else if (type == SOFTMAX) {
+} else if ((type == SOFTMAX) || (type == LAYER_NORM)) {
return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") ||
(type == "FusedBatchNormEx") || (type == "Dropout")) {
(type == "FusedBatchNormEx") || (type == "Dropout") || (type == BATCH_MATMUL)) {
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
} else {
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
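
Context for the dispatch above: the ops routed to MakeDataParallelStrategy (with BATCH_MATMUL newly added to that list) are only split along the batch dimension, i.e. the leading dimension of each input is divided across devices and every other dimension stays whole. A rough sketch of that idea, with locally defined stand-ins for the codebase typedefs and an assumed divisibility check, not the actual MindSpore implementation:

    #include <cstdint>
    #include <vector>

    using Dimensions = std::vector<int64_t>;  // per-input partition counts, one entry per tensor axis
    using Strategys = std::vector<Dimensions>;

    Strategys DataParallelSketch(const std::vector<std::vector<int64_t>> &input_shapes, int64_t device_num) {
      Strategys result;
      for (const auto &shape : input_shapes) {
        Dimensions dims(shape.size(), 1);               // leave every axis unsplit by default
        if (!shape.empty() && shape[0] % device_num == 0) {
          dims[0] = device_num;                         // split only the batch axis when it divides evenly
        }
        result.push_back(dims);
      }
      return result;
    }
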

View File

@@ -51,7 +51,8 @@ enum OperatorType {
kRecReduce,
kRecPReLU,
kRecGatherV2,
-kRecArgWithValue
+kRecArgWithValue,
+kRecUnsortedSegmentOp
};
enum InfoType { kApplication, kConstant };

View File

@@ -61,6 +61,10 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
NewOp.tensor_parm = MakeTensor(
ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]);
+} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 3) {
+NewOp.tensor_parm = MakeTensor(1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
+ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
+ops[iter_ops]->outputs_tensor_info()[0].shape()[2]);
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
ops[iter_ops]->outputs_tensor_info()[0].shape()[1]);
@@ -69,7 +73,7 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) {
NewOp.tensor_parm = MakeTensor(1, 1, 1, 1);
} else {
MS_LOG(ERROR) << "Tensor's shape is unknown.";
MS_LOG(ERROR) << ops[iter_ops]->name() << ": output tensor shape is unexpected.";
}
NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp);
@@ -90,6 +94,11 @@ OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInf
ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2],
ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]);
+} else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 3) {
+NewTensor.apply.arguments[iter_input_tensors] =
+MakeTensor(1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0],
+ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
+ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2]);
} else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) {
NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor);
} else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) {
@@ -98,7 +107,7 @@ OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInf
} else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) {
NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1);
} else {
MS_LOG(ERROR) << "Tensor's shape is unknown.";
MS_LOG(ERROR) << ops[iter_ops]->name() << ": input tensor shape is unexpected.";
}
}
return NewTensor.apply;
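
The new rank-3 branches above follow the same convention as the existing rank-2/1/0 cases: the strategy-search graph stores every tensor as a fixed four-field (n, c, h, w) TensorParam, so lower-rank shapes are left-padded with 1s, e.g. (s0, s1, s2) becomes (1, s0, s1, s2). A small illustrative helper capturing that rule (PadTo4D is hypothetical, not part of the patch):

    #include <array>
    #include <cstdint>
    #include <vector>

    std::array<int64_t, 4> PadTo4D(const std::vector<int64_t> &shape) {
      std::array<int64_t, 4> padded = {1, 1, 1, 1};
      const size_t offset = 4 - shape.size();  // assumes shape.size() <= 4
      for (size_t i = 0; i < shape.size(); ++i) {
        padded[offset + i] = shape[i];
      }
      return padded;
    }
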

View File

@@ -47,6 +47,7 @@ const std::map<std::string, OperatorType> DictOpType{
{BIAS_ADD, OperatorType::kRecBiasAdd},
{BATCH_NORM, OperatorType::kRecBatchNorm},
{FUSE_BATCH_NORM, OperatorType::kRecBatchNorm},
+{LAYER_NORM, OperatorType::kRecBatchNorm},
{SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits},
{ONEHOT, OperatorType::kRecOneHot},
{SQUEEZE, OperatorType::kRecSqueeze},
@@ -58,6 +59,9 @@ const std::map<std::string, OperatorType> DictOpType{
{GATHERV2, OperatorType::kRecGatherV2},
{ARGMAXWITHVALUE, OperatorType::kRecArgWithValue},
{ARGMINWITHVALUE, OperatorType::kRecArgWithValue},
+{UNSORTED_SEGMENT_SUM, OperatorType::kRecUnsortedSegmentOp},
+{UNSORTED_SEGMENT_MAX, OperatorType::kRecUnsortedSegmentOp},
+{UNSORTED_SEGMENT_MIN, OperatorType::kRecUnsortedSegmentOp},
// Activation OP
{ACTIVATION, OperatorType::kRecReLU},
{RELU, OperatorType::kRecReLU},
@@ -139,7 +143,8 @@ const std::map<std::string, OperatorType> DictOpType{
{ASSIGN, OperatorType::kRecElmWiseOp},
{ASSIGN_ADD, OperatorType::kRecElmWiseOp},
{ASSIGN_SUB, OperatorType::kRecElmWiseOp},
{"AssignAdd", OperatorType::kRecElmWiseOp}};
{"AssignAdd", OperatorType::kRecElmWiseOp},
{DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp}};
const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w);

View File

@@ -76,7 +76,8 @@ double GetWeights(const Graph::NodeType &node) {
return cost_ptr->GetMinCostIn();
} else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot ||
-op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax ||
+op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecUnsortedSegmentOp ||
+op.op_type == OperatorType::kRecSoftmax ||
op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
// For BatchParallel op
@@ -172,7 +173,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
} else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot ||
node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax ||
-node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
+node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
+node.apply.op_type == kRecUnsortedSegmentOp) {
// For BatchParallel type
auto cost_ptr = std::make_shared<CostBatchParallel>();
return cost_ptr->GetOptimalStr(node);