!26570 [Auto-par][d-rec] Change Onehot OP type to increase partitioning quality

Merge pull request !26570 from petitquentin/Gather_version_update
This commit is contained in:
i-robot 2021-11-23 14:31:56 +00:00 committed by Gitee
commit 46e53a51c9
3 changed files with 41 additions and 25 deletions

View File

@ -170,30 +170,42 @@ Strategys PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
return strategies;
}
Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops) {
Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
Strategys PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
Strategys strategies;
int64_t axis = -1;
auto iter = ops[iter_ops]->attrs().find(AXIS);
if (iter != ops[iter_ops]->attrs().end()) {
MS_EXCEPTION_IF_NULL(iter->second);
if (iter->second->isa<Int64Imm>()) {
axis = iter->second->cast<Int64ImmPtr>()->value();
} else {
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t.";
// The first input tensor of OneHot should be treated as 2-dimensional even if its Shape size is 1. Derive the
// second dimension by dividing the number of devices by the partition parts of the first dimension.
size_t s_second = 1;
if (s[0] != 0) {
s_second = g_device_manager->DeviceNum() / s[0];
}
if (s.size() == 1) {
s.push_back(s_second);
}
// Partition number should not exceed the number of devices
size_t n_parts = 1;
for (size_t i = 0; i < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); i++) {
n_parts *= ops[iter_ops]->outputs_tensor_info()[0].shape()[i];
}
if (n_parts > s_second) {
s.clear();
for (size_t i = 0; i < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); i++) {
s.push_back(1);
}
}
if (axis == -1) {
strategies[0][0] = strategies[0][1];
strategies[0][1] = 1;
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
}
strategies.push_back(s);
// Push two empty Dimensions for the other two input tensors.
Dimensions s_empty = {};
strategies.push_back(s_empty);
strategies.push_back(s_empty);
return strategies;
}
@ -565,8 +577,6 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
auto type = ops[iter_ops]->type();
if (type == MATMUL) {
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
} else if (type == ONEHOT) {
return PrepareOneHot(graph, ops, iter_graph, iter_ops);
} else if (type == LAYER_NORM) {
return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == DROPOUT) || (type == BATCH_MATMUL)) {
@ -896,8 +906,11 @@ Dimensions ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<Operato
}
Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index) {
const size_t iter_ops, const size_t incoming_op_index) {
Dimensions s;
if (ops[iter_ops]->type() == ONEHOT) {
return s;
}
s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index);
if (s.size() != 0) {
if (ops[incoming_op_index]->type() == SQUEEZE) {
@ -947,6 +960,9 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl;
}
}
if (ops[iter_ops]->type() == ONEHOT) {
return PrepareOneHot(ops, iter_ops, basic_stra);
}
if (ops[iter_ops]->type() == L2_NORMALIZE) {
return PrepareL2Normalize(ops, iter_ops, basic_stra);
}
@ -1105,7 +1121,7 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &gra
if (iter_graph != SIZE_MAX) {
s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph, incoming_op_index);
} else {
s = CopyIncomingOperatorInputStrategy(ops, incoming_op_index);
s = CopyIncomingOperatorInputStrategy(ops, iter_ops, incoming_op_index);
}
}

View File

@ -40,8 +40,7 @@ Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &
Dimensions basic_stra);
Strategys PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions basic_stra);
Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
Strategys PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops);
@ -88,7 +87,7 @@ Dimensions GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>>
Dimensions ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index, Dimensions s);
Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index);
const size_t iter_ops, const size_t incoming_op_index);
Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions basic_stra);
void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,

View File

@ -33,7 +33,8 @@ static const std::set<OperatorType> ElementWiseOpType = {
OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd,
OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul,
OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue, OperatorType::kRecSoftmax};
OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue, OperatorType::kRecSoftmax,
OperatorType::kRecOneHot};
const std::map<std::string, OperatorType> DictOpType{
{MATMUL, OperatorType::kRecMatMul},