Fix FusedBatchNormEx and SparseSoftmax for GPU

Author: root
Date: 2020-10-05 15:26:21 +08:00
parent c1b9efe8e6
commit a0874000d8
1 changed file with 9 additions and 2 deletions


@@ -368,14 +368,19 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
       for (size_t dim = 0; dim < input_size; dim++) {
         if (input_size == 1 || input_size == 2 || input_size == 4) {
           if (dim == 0) {
-            s.push_back(std::min(max_device_num, target_tensor_batch));
+            // Currently GPU version does not support partitioning FusedBatchNormEx in its param tensors.
+            if (ops[iter_ops]->type() == "FusedBatchNormEx" && iter_op_inputs != 0) {
+              s.push_back(1);
+            } else {
+              s.push_back(std::min(max_device_num, target_tensor_batch));
+            }
           } else {
             s.push_back(1);
           }
         } else if (input_size == 0) {
           s = {};
         } else {
-          MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
+          MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
         }
       }
       strategies.push_back(s);
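
For orientation, here is a minimal, self-contained C++ sketch of the rule this hunk encodes: dim 0 (the batch dimension) is split across devices, except for the non-first inputs of FusedBatchNormEx, whose GPU kernel cannot partition the parameter tensors. OpInfo, MakeDataParallelStrategySketch, and the hard-coded shapes below are illustrative stand-ins, not MindSpore's real types, and the real function's input_size checks and exception branch are omitted.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for the operator metadata the real function reads
// from the cost-model graph.
struct OpInfo {
  std::string type;
  std::vector<std::vector<int64_t>> input_shapes;
};

// One strategy vector per input tensor: split only dim 0 (the batch
// dimension); every non-first input of FusedBatchNormEx stays unsplit,
// mirroring the guard added in the hunk above.
std::vector<std::vector<int64_t>> MakeDataParallelStrategySketch(
    const OpInfo &op, int64_t max_device_num, int64_t target_tensor_batch) {
  std::vector<std::vector<int64_t>> strategies;
  for (size_t input_idx = 0; input_idx < op.input_shapes.size(); ++input_idx) {
    std::vector<int64_t> s;
    const size_t input_size = op.input_shapes[input_idx].size();
    for (size_t dim = 0; dim < input_size; ++dim) {
      bool is_bn_param = (op.type == "FusedBatchNormEx" && input_idx != 0);
      if (dim == 0 && !is_bn_param) {
        s.push_back(std::min(max_device_num, target_tensor_batch));
      } else {
        s.push_back(1);
      }
    }
    strategies.push_back(s);
  }
  return strategies;
}

int main() {
  // Input 0 is the 4-D feature map; inputs 1..4 are the 1-D scale, offset,
  // mean, and variance parameters.
  OpInfo bn{"FusedBatchNormEx", {{32, 64, 14, 14}, {64}, {64}, {64}, {64}}};
  for (const auto &s : MakeDataParallelStrategySketch(bn, 8, 32)) {
    for (int64_t v : s) std::cout << v << ' ';
    std::cout << '\n';  // prints "8 1 1 1", then "1" four times
  }
  return 0;
}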
@@ -416,6 +421,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == ONEHOT) {
     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
+  } else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
+    return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
   }
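
This second hunk routes SparseSoftmaxCrossEntropyWithLogits to the plain data-parallel strategy instead of the recursive strategy search, which, per the commit title, is what its GPU kernel can actually execute. A hedged sketch of that dispatch shape follows; the enum and DispatchSketch are illustrative, not MindSpore API, and the real code returns strategies built by the functions named above.

#include <iostream>
#include <string>

// Illustrative strategy kinds standing in for the strategy-building
// functions dispatched in the hunk above.
enum class StrategyKind { kMatMul, kOneHot, kDataParallel, kRecSearch };

// Ops with a dedicated preparer get it; SparseSoftmaxCrossEntropyWithLogits
// falls back to pure data parallelism; everything else goes through the
// recursive strategy search.
StrategyKind DispatchSketch(const std::string &type) {
  if (type == "MatMul") return StrategyKind::kMatMul;
  if (type == "OneHot") return StrategyKind::kOneHot;
  if (type == "SparseSoftmaxCrossEntropyWithLogits") return StrategyKind::kDataParallel;
  return StrategyKind::kRecSearch;
}

int main() {
  std::cout << (DispatchSketch("SparseSoftmaxCrossEntropyWithLogits") ==
                StrategyKind::kDataParallel)
            << '\n';  // prints 1
  return 0;
}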