forked from mindspore-Ecosystem/mindspore
!26570 [Auto-par][d-rec] Change Onehot OP type to increase partitioning quality
Merge pull request !26570 from petitquentin/Gather_version_update
commit 46e53a51c9
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc (58 changes) Executable file → Normal file
@@ -170,30 +170,42 @@ Strategys PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
   return strategies;
 }
 
-Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                        const size_t iter_graph, const size_t iter_ops) {
-  Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
-
-  int64_t axis = -1;
-  auto iter = ops[iter_ops]->attrs().find(AXIS);
-  if (iter != ops[iter_ops]->attrs().end()) {
-    MS_EXCEPTION_IF_NULL(iter->second);
-    if (iter->second->isa<Int64Imm>()) {
-      axis = iter->second->cast<Int64ImmPtr>()->value();
-    } else {
-      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t.";
-    }
-  }
-
-  if (axis == -1) {
-    strategies[0][0] = strategies[0][1];
-    strategies[0][1] = 1;
-    graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
-    graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
-  }
+Strategys PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
+  Strategys strategies;
+
+  // The Dimension size of the first input tensor of OneHot should be 2 even its Shape size is 1. Using the division of
+  // the number of devices and the partition parts of the first dimension.
+  size_t s_second = 1;
+  if (s[0] != 0) {
+    s_second = g_device_manager->DeviceNum() / s[0];
+  }
+
+  if (s.size() == 1) {
+    s.push_back(s_second);
+  }
+
+  // Partition number should not exceed the number of devices
+  size_t n_parts = 1;
+  for (size_t i = 0; i < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); i++) {
+    n_parts *= ops[iter_ops]->outputs_tensor_info()[0].shape()[i];
+  }
+
+  if (n_parts > s_second) {
+    s.clear();
+    for (size_t i = 0; i < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); i++) {
+      s.push_back(1);
+    }
+  }
+
+  strategies.push_back(s);
+
+  // Push two empty Dimensions for the other two input tensors.
+  Dimensions s_empty = {};
+  strategies.push_back(s_empty);
+  strategies.push_back(s_empty);
+
   return strategies;
 }
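Reviewer note: the division step in the new PrepareOneHot is easiest to see with concrete numbers. The standalone C++ sketch below (not MindSpore source; the device count and incoming strategy are assumed values) reproduces the arithmetic applied to the first input's strategy vector `s`.

// Standalone sketch: mirrors how PrepareOneHot derives the second partition
// dimension from the devices left over after the first dimension.
#include <cstdint>
#include <iostream>
#include <vector>

using Dimensions = std::vector<int64_t>;  // same alias the rec_core code uses

int main() {
  const int64_t device_num = 8;  // stand-in for g_device_manager->DeviceNum()
  Dimensions s = {2};            // assumed: first dimension split into 2 parts

  // OneHot's first input is treated as 2-D even when its shape is 1-D:
  // the second dimension receives the devices the first did not consume.
  int64_t s_second = 1;
  if (s[0] != 0) {
    s_second = device_num / s[0];
  }
  if (s.size() == 1) {
    s.push_back(s_second);
  }

  std::cout << s[0] << " x " << s[1] << std::endl;  // prints "2 x 4"
  return 0;
}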
@@ -565,8 +577,6 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
   auto type = ops[iter_ops]->type();
   if (type == MATMUL) {
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
-  } else if (type == ONEHOT) {
-    return PrepareOneHot(graph, ops, iter_graph, iter_ops);
   } else if (type == LAYER_NORM) {
     return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
   } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == DROPOUT) || (type == BATCH_MATMUL)) {
@@ -896,8 +906,11 @@ Dimensions ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 }
 
 Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                             const size_t incoming_op_index) {
+                                             const size_t iter_ops, const size_t incoming_op_index) {
   Dimensions s;
+  if (ops[iter_ops]->type() == ONEHOT) {
+    return s;
+  }
   s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index);
   if (s.size() != 0) {
     if (ops[incoming_op_index]->type() == SQUEEZE) {
@@ -947,6 +960,9 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
       MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl;
     }
   }
+  if (ops[iter_ops]->type() == ONEHOT) {
+    return PrepareOneHot(ops, iter_ops, basic_stra);
+  }
   if (ops[iter_ops]->type() == L2_NORMALIZE) {
     return PrepareL2Normalize(ops, iter_ops, basic_stra);
   }
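The ONEHOT branch added here routes OneHot through the strategy-propagation path instead of the removed graph-based PrepareOneHot. For illustration, this standalone sketch (the helper name is hypothetical; the three push_back slots mirror the diff's comment about "the other two input tensors") shows the layout the new PrepareOneHot returns.

// Standalone sketch: assumed layout of the returned Strategys — only the
// first input carries partition parts, the other two slots stay empty.
#include <cstdint>
#include <vector>

using Dimensions = std::vector<int64_t>;
using Strategys = std::vector<Dimensions>;

Strategys MakeOneHotStrategy(Dimensions s) {  // hypothetical stand-in
  Strategys strategies;
  strategies.push_back(s);             // first input, e.g. {2, 4}
  strategies.push_back(Dimensions{});  // second input: left unpartitioned
  strategies.push_back(Dimensions{});  // third input: left unpartitioned
  return strategies;
}

int main() {
  Strategys st = MakeOneHotStrategy({2, 4});
  return st.size() == 3 ? 0 : 1;  // three slots, as the diff pushes
}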
@@ -1105,7 +1121,7 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
     if (iter_graph != SIZE_MAX) {
       s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph, incoming_op_index);
     } else {
-      s = CopyIncomingOperatorInputStrategy(ops, incoming_op_index);
+      s = CopyIncomingOperatorInputStrategy(ops, iter_ops, incoming_op_index);
     }
   }

mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
@@ -40,8 +40,7 @@ Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                               Dimensions basic_stra);
 Strategys PrepareSoftMax(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                          Dimensions basic_stra);
-Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                        const size_t iter_graph, const size_t iter_ops);
+Strategys PrepareOneHot(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
 Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
                                      const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
                                      const size_t iter_ops);
@@ -88,7 +87,7 @@ Dimensions GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 Dimensions ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                        const size_t incoming_op_index, Dimensions s);
 Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                             const size_t incoming_op_index);
+                                             const size_t iter_ops, const size_t incoming_op_index);
 Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                          Dimensions basic_stra);
 void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,

mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h
@@ -33,7 +33,8 @@ static const std::set<OperatorType> ElementWiseOpType = {
     OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd,
     OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul,
     OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
-    OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue, OperatorType::kRecSoftmax};
+    OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue, OperatorType::kRecSoftmax,
+    OperatorType::kRecOneHot};
 
 const std::map<std::string, OperatorType> DictOpType{
     {MATMUL, OperatorType::kRecMatMul},
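Adding OperatorType::kRecOneHot to ElementWiseOpType is the "OP type" change the commit title refersps to the point: membership in this set appears to be how the recursive partitioner decides an operator can reuse element-wise strategy handling. A standalone sketch of such a membership check (the enum and helper below are stand-ins for illustration, not MindSpore code):

// Standalone sketch: set-membership test over a stand-in OperatorType enum.
#include <set>

enum class OperatorType { kRecReLU, kRecMatMul, kRecOneHot };

static const std::set<OperatorType> ElementWiseOpType = {
    OperatorType::kRecReLU, OperatorType::kRecOneHot};

bool IsElementWiseOp(OperatorType type) {  // hypothetical helper
  return ElementWiseOpType.find(type) != ElementWiseOpType.end();  // OneHot now matches
}

int main() { return IsElementWiseOp(OperatorType::kRecOneHot) ? 0 : 1; }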