forked from mindspore-Ecosystem/mindspore
!23434 [CT][MS][parallel] fix VirtualData outgoing op strategy copying bug
Merge pull request !23434 from Chong/PanGu-VirtualDataset
This commit is contained in:
commit
f112c42027
|
@ -153,6 +153,31 @@ Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s) {
|
|||
return strategies;
|
||||
}
|
||||
|
||||
Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
|
||||
Dimensions basic_stra) {
|
||||
Strategys stra;
|
||||
|
||||
auto begin = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(1));
|
||||
auto end = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(2));
|
||||
auto strides = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(3));
|
||||
|
||||
for (size_t i = 0; i < strides.size(); ++i) {
|
||||
if ((strides[i] != 1) && (basic_stra[i] > 1)) {
|
||||
basic_stra[i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < begin.size(); ++i) {
|
||||
bool no_fully_fetch = ((begin[i] != 0) || (end[i] < ops[iter_ops]->inputs_tensor_info()[0].shape()[i]));
|
||||
if (no_fully_fetch && (basic_stra[i] != 1)) {
|
||||
basic_stra[i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
stra.push_back(basic_stra);
|
||||
return stra;
|
||||
}
|
||||
|
||||
Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops) {
|
||||
Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
|
||||
|
@ -634,16 +659,18 @@ Dimensions CopyVirtualDataset(const std::shared_ptr<Graph> &graph,
|
|||
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
|
||||
const size_t iter_graph) {
|
||||
Dimensions s;
|
||||
for (auto input : ops[iter_ops]->inputs_tensor_info()) {
|
||||
auto input_stra_dim = input.shape().size();
|
||||
auto virtual_dataset_str = CheckVirtualDatasetStrategy(graph, iter_graph);
|
||||
if (input_stra_dim == 0) {
|
||||
continue;
|
||||
auto input_stra_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
|
||||
auto virtual_dataset_str = CheckVirtualDatasetStrategy(graph, iter_graph);
|
||||
if (input_stra_dim == 0) {
|
||||
return s;
|
||||
} else {
|
||||
if (virtual_dataset_str == 0) {
|
||||
s.push_back(1);
|
||||
} else {
|
||||
s.push_back(1 / virtual_dataset_str);
|
||||
for (size_t i = 1; i < input_stra_dim; i++) {
|
||||
s.push_back(1);
|
||||
}
|
||||
}
|
||||
for (size_t i = 1; i < input_stra_dim; i++) {
|
||||
s.push_back(1);
|
||||
}
|
||||
}
|
||||
return s;
|
||||
|
@ -930,6 +957,9 @@ Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<Opera
|
|||
if (ops[iter_ops]->type() == BIAS_ADD) {
|
||||
return PrepareBiasAdd(s_ptr);
|
||||
}
|
||||
if (ops[iter_ops]->type() == STRIDED_SLICE) {
|
||||
return PrepareStridedSlice(ops, iter_ops, basic_stra);
|
||||
}
|
||||
if (ops[iter_ops]->type() == GATHERV2) {
|
||||
auto pos = ops[iter_ops]->name().find("Info");
|
||||
auto name = ops[iter_ops]->name().substr(0, pos);
|
||||
|
|
|
@ -34,6 +34,8 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
|
|||
Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops);
|
||||
Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
|
||||
Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
|
||||
Dimensions basic_stra);
|
||||
Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
||||
const size_t iter_graph, const size_t iter_ops);
|
||||
Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
|
||||
|
|
Loading…
Reference in New Issue