forked from mindspore-Ecosystem/mindspore
!3272 [AutoParallel] Adjust partition strategy of elementwise operators for the implicit broadcast cases
Merge pull request !3272 from Chong/NewWideAndDeep
commit b29fab3e9c
@@ -614,7 +614,6 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::sh

std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                                 const size_t iter_ops,
                                                                 std::vector<int32_t> basic_stra) {
  std::vector<int32_t> s_empty = {};
  std::vector<std::vector<int32_t>> stra;
  MS_EXCEPTION_IF_NULL(ops[iter_ops]);
@@ -636,9 +635,99 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect

  if (ops[iter_ops]->type() == L2_NORMALIZE) {
    return PrepareL2Normalize(ops, iter_ops, basic_stra);
  }
  if (ops[iter_ops]->type() == TENSOR_ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL ||
      ops[iter_ops]->type() == DIV) {
    return CheckBroadcast(ops, iter_ops, basic_stra);
  }

  return CheckDivisible(ops, iter_ops, basic_stra);
}

// Handle ops with broadcasting, such as TensorAdd/Sub/Mul/Div.
std::vector<std::vector<int32_t>> CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_ops, std::vector<int32_t> s) {
  std::vector<std::vector<int32_t>> stra;

  size_t first_tensor_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
  size_t second_tensor_dim = ops[iter_ops]->inputs_tensor_info()[1].shape().size();

  // Broadcast the second tensor.
  if (second_tensor_dim < first_tensor_dim) {
    bool broadcast_first_tensor = false;
    // Push back the first tensor's strategy.
    stra.push_back(s);
    // Push back the second tensor's strategy after applying broadcast.
    stra.push_back(ApplyBroadcast(ops, iter_ops, s, second_tensor_dim, first_tensor_dim, broadcast_first_tensor));
  } else if (second_tensor_dim > first_tensor_dim) {  // Broadcast the first tensor.
    bool broadcast_first_tensor = true;
    // Push back the first tensor's strategy after applying broadcast.
    stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, broadcast_first_tensor));
    // Push back the second tensor's strategy.
    stra.push_back(s);
  } else {  // The dims match, so no broadcasting needs to be applied.
    stra = CheckDivisible(ops, iter_ops, s);
  }

  return stra;
}
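// Worked example (hypothetical shapes, for illustration only): for a TensorAdd
// whose inputs have shapes [64, 32] and [32] and whose incoming strategy is
// s = {8, 4}, second_tensor_dim (1) < first_tensor_dim (2), so the first input
// keeps {8, 4} and ApplyBroadcast derives {4} for the broadcast input by
// reusing the cut of the matching dim. The result is stra = {{8, 4}, {4}}.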

std::vector<int32_t> ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                    std::vector<int32_t> s, size_t target_tensor_dim, size_t refer_tensor_dim,
                                    bool broadcast_first_tensor) {
  std::vector<int32_t> s_empty = {};
  std::vector<int32_t> s_broadcast;
  int target_tensor_index = 0;
  int refer_tensor_index = 0;

  // Decide which input is the target (broadcast) tensor and which is the refer tensor.
  if (broadcast_first_tensor) {
    target_tensor_index = 0;
    refer_tensor_index = 1;
  } else {
    target_tensor_index = 1;
    refer_tensor_index = 0;
  }

  // The target tensor has an empty shape.
  if (target_tensor_dim == 0) {
    return s_empty;
  } else if (target_tensor_dim == 1) {  // The target tensor has a single dim.
    bool broadcast_dim_found = false;
    for (size_t iter = 0; iter < refer_tensor_dim; iter++) {
      // Find the refer dim that matches the target's dim and copy its strategy.
      if ((ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] ==
           ops[iter_ops]->inputs_tensor_info()[target_tensor_index].shape()[0]) &&
          (ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] > 1) &&
          (refer_tensor_dim == s.size())) {
        s_broadcast.push_back(s.at(iter));
        broadcast_dim_found = true;
        break;
      }
    }
    // No matching dim was found, so push back 1 (no partitioning).
    if (!broadcast_dim_found) {
      s_broadcast.push_back(1);
    }
  } else {
    // Cannot decide which dims need broadcasting, so push back 1 for each dim.
    for (size_t iter = 0; iter < target_tensor_dim; iter++) {
      s_broadcast.push_back(1);
    }
  }

  return s_broadcast;
}
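// Worked example of the fallback branches (hypothetical shapes): broadcasting a
// [1] input against [64, 32] finds no refer dim of the same size greater than 1,
// so the result is {1}; broadcasting a multi-dim input such as [16, 32] against
// [8, 16, 32] takes the conservative branch and returns {1, 1}, i.e. the
// broadcast input is not cut at all.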

// Check whether the operator's input tensors can be divided by the current strategy.
std::vector<std::vector<int32_t>> CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_ops, std::vector<int32_t> basic_stra) {
  std::vector<int32_t> s_empty = {};
  std::vector<std::vector<int32_t>> stra;

  // For all the input tensors.
  for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
       iter_op_inputs++) {
    // If the input tensor is empty, push back an empty strategy.
    if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) {
      stra.push_back(s_empty);
      continue;
@@ -646,6 +735,8 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect

    std::vector<int32_t> tmp_stra = basic_stra;
    bool modified = false;

    // If a dim's shape is 1, it cannot be cut further, so force its strategy to 1.
    for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) {
      if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) {
        tmp_stra[j] = 1;
@@ -658,6 +749,7 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect

      stra.push_back(basic_stra);
    }
  }

  return stra;
}
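
To make the broadcast-matching rule concrete, the following is a minimal
standalone sketch of ApplyBroadcast's single-dim case, using plain std::vector
stand-ins for the OperatorInfo/TensorInfo queries; the function name, shapes,
and strategy below are hypothetical and chosen only for illustration.

#include <cstdint>
#include <iostream>
#include <vector>

// Derive a strategy for a 1-D broadcast input: reuse the cut of the refer dim
// whose length matches the target's single dim, otherwise fall back to 1.
std::vector<int32_t> ApplyBroadcastSketch(const std::vector<int32_t> &refer_shape,
                                          const std::vector<int32_t> &s, int32_t target_len) {
  for (size_t i = 0; i < refer_shape.size(); ++i) {
    if (refer_shape[i] == target_len && refer_shape[i] > 1 && refer_shape.size() == s.size()) {
      return {s[i]};  // Copy the matching dim's strategy.
    }
  }
  return {1};  // No matching dim: do not cut the broadcast input.
}

int main() {
  // TensorAdd([64, 32], [32]) with strategy {8, 4} on the first input:
  std::vector<int32_t> result = ApplyBroadcastSketch({64, 32}, {8, 4}, 32);
  std::cout << result[0] << std::endl;  // Prints 4: dim 1's cut is reused.
  return 0;
}
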
@@ -42,6 +42,13 @@ std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_

                                                  const size_t iter_ops, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                     const size_t iter_ops, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                    std::vector<int32_t> s, size_t target_tensor_dim, size_t refer_tensor_dim,
                                    bool broadcast_first_tensor);
std::vector<std::vector<int32_t>> CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_ops, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                        const size_t iter_graph, const size_t iter_ops);