!23434 [CT][MS][parallel] fix VirtualData outgoing op strategy copying bug

Merge pull request !23434 from Chong/PanGu-VirtualDataset
This commit is contained in:
i-robot 2021-09-16 03:54:35 +00:00 committed by Gitee
commit f112c42027
2 changed files with 40 additions and 8 deletions

View File

@ -153,6 +153,31 @@ Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s) {
return strategies;
}
Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions basic_stra) {
Strategys stra;
auto begin = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(1));
auto end = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(2));
auto strides = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(3));
for (size_t i = 0; i < strides.size(); ++i) {
if ((strides[i] != 1) && (basic_stra[i] > 1)) {
basic_stra[i] = 1;
}
}
for (size_t i = 0; i < begin.size(); ++i) {
bool no_fully_fetch = ((begin[i] != 0) || (end[i] < ops[iter_ops]->inputs_tensor_info()[0].shape()[i]));
if (no_fully_fetch && (basic_stra[i] != 1)) {
basic_stra[i] = 1;
}
}
stra.push_back(basic_stra);
return stra;
}
Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops) {
Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
@ -634,16 +659,18 @@ Dimensions CopyVirtualDataset(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_graph) {
Dimensions s;
for (auto input : ops[iter_ops]->inputs_tensor_info()) {
auto input_stra_dim = input.shape().size();
auto virtual_dataset_str = CheckVirtualDatasetStrategy(graph, iter_graph);
if (input_stra_dim == 0) {
continue;
auto input_stra_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
auto virtual_dataset_str = CheckVirtualDatasetStrategy(graph, iter_graph);
if (input_stra_dim == 0) {
return s;
} else {
if (virtual_dataset_str == 0) {
s.push_back(1);
} else {
s.push_back(1 / virtual_dataset_str);
for (size_t i = 1; i < input_stra_dim; i++) {
s.push_back(1);
}
}
for (size_t i = 1; i < input_stra_dim; i++) {
s.push_back(1);
}
}
return s;
@ -930,6 +957,9 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
if (ops[iter_ops]->type() == BIAS_ADD) {
return PrepareBiasAdd(s_ptr);
}
if (ops[iter_ops]->type() == STRIDED_SLICE) {
return PrepareStridedSlice(ops, iter_ops, basic_stra);
}
if (ops[iter_ops]->type() == GATHERV2) {
auto pos = ops[iter_ops]->name().find("Info");
auto name = ops[iter_ops]->name().substr(0, pos);

View File

@ -34,6 +34,8 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Dimensions basic_stra);
Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,