forked from mindspore-Ecosystem/mindspore
Fix FusedBatchNormEx and SparseSoftmax for GPU
This commit is contained in:
parent
c1b9efe8e6
commit
a0874000d8
|
@ -368,14 +368,19 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
|
|||
for (size_t dim = 0; dim < input_size; dim++) {
|
||||
if (input_size == 1 || input_size == 2 || input_size == 4) {
|
||||
if (dim == 0) {
|
||||
s.push_back(std::min(max_device_num, target_tensor_batch));
|
||||
// Currently GPU version does not support partitioning ‘FusedBatchNormEx’ in its param tensors.
|
||||
if (ops[iter_ops]->type() == "FusedBatchNormEx" && iter_op_inputs != 0) {
|
||||
s.push_back(1);
|
||||
} else {
|
||||
s.push_back(std::min(max_device_num, target_tensor_batch));
|
||||
}
|
||||
} else {
|
||||
s.push_back(1);
|
||||
}
|
||||
} else if (input_size == 0) {
|
||||
s = {};
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
|
||||
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
|
||||
}
|
||||
}
|
||||
strategies.push_back(s);
|
||||
|
@ -416,6 +421,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
|
|||
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
|
||||
} else if (type == ONEHOT) {
|
||||
return PrepareOneHot(graph, ops, iter_graph, iter_ops);
|
||||
} else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
|
||||
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
||||
} else {
|
||||
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue