From a0874000d8a3ab1f40a1c5d19d6cbb023c622804 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 5 Oct 2020 15:26:21 +0800
Subject: [PATCH] Fix FusedBatchNormEx and SparseSoftmax for GPU

---
 .../auto_parallel/rec_core/rec_generate_strategy.cc | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
index 02a06c9bebc..18623e9d923 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -368,14 +368,19 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
     for (size_t dim = 0; dim < input_size; dim++) {
       if (input_size == 1 || input_size == 2 || input_size == 4) {
         if (dim == 0) {
-          s.push_back(std::min(max_device_num, target_tensor_batch));
+          // Currently the GPU version does not support partitioning the param tensors of 'FusedBatchNormEx'.
+          if (ops[iter_ops]->type() == "FusedBatchNormEx" && iter_op_inputs != 0) {
+            s.push_back(1);
+          } else {
+            s.push_back(std::min(max_device_num, target_tensor_batch));
+          }
         } else {
           s.push_back(1);
         }
       } else if (input_size == 0) {
         s = {};
       } else {
-        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
+        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
       }
     }
     strategies.push_back(s);
@@ -416,6 +421,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == ONEHOT) {
     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
+  } else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
+    return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
   }
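
Illustration (separate from the patch itself): a minimal, standalone C++ sketch of the data-parallel strategy rule the first hunk implements: only dim 0 of each input is split across devices, and the param tensors of FusedBatchNormEx (every input other than input 0) stay unsplit on GPU. MakeStrategyForInput and the example device/batch numbers below are hypothetical stand-ins, not MindSpore APIs; the real code takes max_device_num from the device manager and target_tensor_batch from the operator's output shape.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

using Dimensions = std::vector<int64_t>;

// Hypothetical helper mirroring the per-input branch added by the patch.
Dimensions MakeStrategyForInput(const std::string &op_type, size_t input_index, size_t input_size,
                                int64_t max_device_num, int64_t target_tensor_batch) {
  Dimensions s;
  for (size_t dim = 0; dim < input_size; dim++) {
    if (dim == 0) {
      // Param tensors (scale, offset, mean, variance) of FusedBatchNormEx are not partitioned on GPU.
      if (op_type == "FusedBatchNormEx" && input_index != 0) {
        s.push_back(1);
      } else {
        s.push_back(std::min(max_device_num, target_tensor_batch));
      }
    } else {
      s.push_back(1);
    }
  }
  return s;
}

int main() {
  // Example: 8 devices, batch size 32. The 4-D data input is split on dim 0,
  // while a 1-D param tensor (input index 1) is left whole.
  Dimensions data = MakeStrategyForInput("FusedBatchNormEx", 0, 4, 8, 32);
  Dimensions scale = MakeStrategyForInput("FusedBatchNormEx", 1, 1, 8, 32);
  for (int64_t d : data) std::cout << d << ' ';   // prints: 8 1 1 1
  std::cout << '\n';
  for (int64_t d : scale) std::cout << d << ' ';  // prints: 1
  std::cout << '\n';
  return 0;
}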