forked from mindspore-Ecosystem/mindspore
!9122 [AutoParallel] add LayerNorm, Dropout and SegmentSum/Max/Min for GPT
From: @ch-l
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
Commit a09f1e30b6
@@ -309,7 +309,14 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
   }
 
   std::vector<int64_t> axis_list;
-  auto iter = ops[iter_ops]->attrs().find(AXIS);
+  string axis_name = AXIS;
+  int64_t default_axis = -1;
+  if (ops[iter_ops]->type() == LAYER_NORM) {
+    axis_name = "begin_norm_axis";
+    default_axis = 1;
+  }
+
+  auto iter = ops[iter_ops]->attrs().find(axis_name);
   if (iter != ops[iter_ops]->attrs().end()) {
     MS_EXCEPTION_IF_NULL(iter->second);
     if (iter->second->isa<Int64Imm>()) {
@@ -326,8 +333,9 @@ Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
       MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t or tuple int64_t.";
     }
   } else {
-    axis_list.push_back(-1);
+    axis_list.push_back(default_axis);
   }
+
   for (auto &axis : axis_list) {
     if (axis < 0) {
       int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
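This pair of hunks is the core of the LayerNorm support: the axis attribute is now looked up under an operator-specific name ("begin_norm_axis" for LayerNorm, AXIS otherwise), and when the attribute is absent the fallback is 1 for LayerNorm instead of -1. The resolved axis is then normalized against the input rank a few lines further down. A minimal standalone sketch of that logic, assuming hypothetical helpers ResolveAxisAttr and NormalizeAxis (the patch itself does this inline in PrepareAxisRelatedStrategy):

#include <cassert>
#include <cstdint>
#include <string>
#include <utility>

// Pick the attribute name and its fallback value for an operator type.
// Illustrative helper, not part of the patch.
std::pair<std::string, int64_t> ResolveAxisAttr(const std::string &op_type) {
  if (op_type == "LayerNorm") {
    return {"begin_norm_axis", 1};  // normalize every dim from axis 1 onward
  }
  return {"axis", -1};  // Softmax-style default: the last dimension
}

// Map a possibly negative axis onto [0, input_dim), as the loop over axis_list does.
int64_t NormalizeAxis(int64_t axis, int64_t input_dim) {
  return axis < 0 ? axis + input_dim : axis;
}

int main() {
  assert(ResolveAxisAttr("LayerNorm").second == 1);
  assert(NormalizeAxis(-1, 3) == 2);  // e.g. Softmax over a rank-3 tensor
  return 0;
}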
@@ -481,10 +489,10 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == ONEHOT) {
     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
-  } else if (type == SOFTMAX) {
+  } else if ((type == SOFTMAX) || (type == LAYER_NORM)) {
     return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
   } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") ||
-             (type == "FusedBatchNormEx") || (type == "Dropout")) {
+             (type == "FusedBatchNormEx") || (type == "Dropout") || (type == BATCH_MATMUL)) {
     return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
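With this dispatch change, LayerNorm shares the axis-aware strategy path with Softmax, while Dropout and BatchMatMul fall back to MakeDataParallelStrategy, which splits only the leading (batch) dimension across devices and leaves the other dimensions whole. A rough illustration of what such a strategy looks like for one input tensor; the Dimensions alias and BuildDataParallelDims are assumptions for the sketch, not the repo's helper, and the real MakeDataParallelStrategy works from the graph's tensor info:

#include <cstdint>
#include <iostream>
#include <vector>

using Dimensions = std::vector<int64_t>;  // number of cuts per tensor dimension

// Split the leading dimension across devices, keep the rest intact (sketch only).
Dimensions BuildDataParallelDims(size_t rank, int64_t device_num) {
  Dimensions dims(rank, 1);
  if (rank > 0) dims[0] = device_num;
  return dims;
}

int main() {
  // A Dropout input of shape (32, 128, 768) on 8 devices -> strategy (8, 1, 1).
  Dimensions s = BuildDataParallelDims(3, 8);
  for (auto d : s) std::cout << d << " ";
  std::cout << std::endl;
  return 0;
}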
@@ -51,7 +51,8 @@ enum OperatorType {
   kRecReduce,
   kRecPReLU,
   kRecGatherV2,
-  kRecArgWithValue
+  kRecArgWithValue,
+  kRecUnsortedSegmentOp
 };
 
 enum InfoType { kApplication, kConstant };
@@ -61,6 +61,10 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
     NewOp.tensor_parm = MakeTensor(
       ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
       ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]);
+  } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 3) {
+    NewOp.tensor_parm = MakeTensor(1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
+                                   ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
+                                   ops[iter_ops]->outputs_tensor_info()[0].shape()[2]);
   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
     NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
                                    ops[iter_ops]->outputs_tensor_info()[0].shape()[1]);
@@ -69,7 +73,7 @@ Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>>
   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) {
     NewOp.tensor_parm = MakeTensor(1, 1, 1, 1);
   } else {
-    MS_LOG(ERROR) << "Tensor's shape is unknown.";
+    MS_LOG(ERROR) << ops[iter_ops]->name() << ": output tensor shape is unexpected.";
   }
 
   NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp);
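MakeNewOperator, and CompleteOperatorInputs in the next two hunks, normalize every tensor to the four-dimensional (n, c, h, w) form the recursive-programming cost model works with: the new rank-3 branch maps (d0, d1, d2) to (1, d0, d1, d2), rank-2 becomes (1, 1, d0, d1), and so on, while the reworded else-branch now names the offending operator. A hedged standalone sketch of that left-padding rule, where PadTo4D and the plain struct are illustrative stand-ins for the repo's MakeTensor/TensorParam:

#include <array>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

struct Tensor4D { int64_t n, c, h, w; };  // stand-in for the rec-core tensor parameter

// Left-pad a shape of rank 0..4 with 1s so it always has four dimensions,
// mirroring the shape-size branches in MakeNewOperator / CompleteOperatorInputs.
Tensor4D PadTo4D(const std::vector<int64_t> &shape) {
  if (shape.size() > 4) throw std::invalid_argument("rank > 4 is not handled here");
  std::array<int64_t, 4> dims = {1, 1, 1, 1};
  size_t offset = 4 - shape.size();
  for (size_t i = 0; i < shape.size(); ++i) dims[offset + i] = shape[i];
  return {dims[0], dims[1], dims[2], dims[3]};
}

int main() {
  Tensor4D t = PadTo4D({16, 128, 768});  // rank-3 output -> (1, 16, 128, 768)
  assert(t.n == 1 && t.c == 16 && t.h == 128 && t.w == 768);
  return 0;
}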
@@ -90,6 +94,11 @@ OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInf
         ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
         ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2],
         ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]);
+    } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 3) {
+      NewTensor.apply.arguments[iter_input_tensors] =
+        MakeTensor(1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0],
+                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
+                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2]);
     } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) {
       NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor);
     } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) {
@@ -98,7 +107,7 @@ OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInf
     } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) {
       NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1);
    } else {
-      MS_LOG(ERROR) << "Tensor's shape is unknown.";
+      MS_LOG(ERROR) << ops[iter_ops]->name() << ": input tensor shape is unexpected.";
     }
   }
   return NewTensor.apply;
@@ -47,6 +47,7 @@ const std::map<std::string, OperatorType> DictOpType{
   {BIAS_ADD, OperatorType::kRecBiasAdd},
   {BATCH_NORM, OperatorType::kRecBatchNorm},
   {FUSE_BATCH_NORM, OperatorType::kRecBatchNorm},
+  {LAYER_NORM, OperatorType::kRecBatchNorm},
   {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits},
   {ONEHOT, OperatorType::kRecOneHot},
   {SQUEEZE, OperatorType::kRecSqueeze},
@@ -58,6 +59,9 @@ const std::map<std::string, OperatorType> DictOpType{
   {GATHERV2, OperatorType::kRecGatherV2},
   {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue},
   {ARGMINWITHVALUE, OperatorType::kRecArgWithValue},
+  {UNSORTED_SEGMENT_SUM, OperatorType::kRecUnsortedSegmentOp},
+  {UNSORTED_SEGMENT_MAX, OperatorType::kRecUnsortedSegmentOp},
+  {UNSORTED_SEGMENT_MIN, OperatorType::kRecUnsortedSegmentOp},
   // Activation OP
   {ACTIVATION, OperatorType::kRecReLU},
   {RELU, OperatorType::kRecReLU},
@@ -139,7 +143,8 @@ const std::map<std::string, OperatorType> DictOpType{
   {ASSIGN, OperatorType::kRecElmWiseOp},
   {ASSIGN_ADD, OperatorType::kRecElmWiseOp},
   {ASSIGN_SUB, OperatorType::kRecElmWiseOp},
-  {"AssignAdd", OperatorType::kRecElmWiseOp}};
+  {"AssignAdd", OperatorType::kRecElmWiseOp},
+  {DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp}};
 
 const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w);
 
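DictOpType is how the recursive-programming pass classifies a primitive by its name: LayerNorm now shares kRecBatchNorm with the batch-norm variants, the three UnsortedSegment ops get the new kRecUnsortedSegmentOp type, and DropoutDoMask is treated as element-wise. A small sketch of the lookup pattern with an unknown-op fallback; the trimmed enum, kRecUnknownOpType, and the literal primitive names are illustrative assumptions standing in for the repo's constants:

#include <iostream>
#include <map>
#include <string>

enum OperatorType { kRecBatchNorm, kRecElmWiseOp, kRecUnsortedSegmentOp, kRecUnknownOpType };

// Trimmed-down version of DictOpType with only the entries this patch touches.
const std::map<std::string, OperatorType> kDict = {
  {"LayerNorm", kRecBatchNorm},
  {"UnsortedSegmentSum", kRecUnsortedSegmentOp},
  {"UnsortedSegmentMax", kRecUnsortedSegmentOp},
  {"UnsortedSegmentMin", kRecUnsortedSegmentOp},
  {"DropoutDoMask", kRecElmWiseOp},
};

OperatorType Classify(const std::string &op_name) {
  auto it = kDict.find(op_name);
  return it == kDict.end() ? kRecUnknownOpType : it->second;
}

int main() {
  std::cout << (Classify("UnsortedSegmentSum") == kRecUnsortedSegmentOp) << std::endl;  // prints 1
  return 0;
}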
@@ -76,7 +76,8 @@ double GetWeights(const Graph::NodeType &node) {
 
     return cost_ptr->GetMinCostIn();
   } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot ||
-             op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax ||
+             op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecUnsortedSegmentOp ||
+             op.op_type == OperatorType::kRecSoftmax ||
             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
             op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
     // For BatchParallel op
@@ -172,7 +173,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
   } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot ||
              node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax ||
-             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
+             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
+             node.apply.op_type == kRecUnsortedSegmentOp) {
     // For BatchParallel type
     auto cost_ptr = std::make_shared<CostBatchParallel>();
     return cost_ptr->GetOptimalStr(node);
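Both cost-model entry points now treat the UnsortedSegment ops the same way as BatchNorm, OneHot, and the softmax losses: GetWeights prices them with the batch-parallel cost class, and PartitionNode asks that same class for the strategy instead of running the general recursive search, effectively restricting these ops to batch-dimension splitting. A condensed sketch of that dispatch decision; the IsBatchParallelOp helper is illustrative, since the real code tests the op type inline as shown in the hunks above:

#include <initializer_list>
#include <iostream>

enum OperatorType { kRecMatMul, kRecBatchNorm, kRecOneHot, kRecSoftmax, kRecUnsortedSegmentOp };

// Ops the rec cost model only partitions along the batch dimension.
bool IsBatchParallelOp(OperatorType type) {
  for (OperatorType t : {kRecBatchNorm, kRecOneHot, kRecSoftmax, kRecUnsortedSegmentOp}) {
    if (type == t) return true;
  }
  return false;
}

int main() {
  std::cout << IsBatchParallelOp(kRecUnsortedSegmentOp) << std::endl;  // 1 -> batch-parallel cost path
  std::cout << IsBatchParallelOp(kRecMatMul) << std::endl;             // 0 -> general partition search
  return 0;
}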