!6839 Add functions that the split size can be -1 at the end of split ops

Merge pull request !6839 from liuwenhao/master
This commit is contained in:
mindspore-ci-bot 2020-09-24 17:20:46 +08:00 committed by Gitee
commit 2ef51bdeae
2 changed files with 8 additions and 8 deletions

View File

@ -50,14 +50,6 @@ int DoSplit(float *in_data, float **out_data, const int *input_shape, int offset
split_which = i % num_split; split_which = i % num_split;
split_times = i / num_split; split_times = i / num_split;
int split_size = split_sizes[split_which]; int split_size = split_sizes[split_which];
// support split size is -1 in the end.
if (split_size == -1) {
int split_dim_i = input_shape[split_dim];
for (int j = 0; j < num_split - 1; ++j) {
split_dim_i -= split_sizes[j];
}
split_size = split_dim_i;
}
float *dst = out_data[split_which] + split_times * in_stride * split_size; float *dst = out_data[split_which] + split_times * in_stride * split_size;
(void)memcpy(dst, src, split_size * in_stride_bytes); (void)memcpy(dst, src, split_size * in_stride_bytes);
src += split_size * in_stride; src += split_size * in_stride;

View File

@ -58,6 +58,14 @@ int SplitBaseCPUKernel::ReSize() {
} }
} }
if (param->split_sizes_[param->num_split_ - 1] == -1) {
int split_shape_end = input_shape[param->split_dim_];
for (int i = 0; i < param->num_split_ - 1; i++) {
split_shape_end -= param->split_sizes_[i];
}
param->split_sizes_[param->num_split_ - 1] = split_shape_end;
}
num_unit_ = param->split_count_ * param->num_split_; num_unit_ = param->split_count_ * param->num_split_;
thread_n_num_ = MSMIN(thread_count_, num_unit_); thread_n_num_ = MSMIN(thread_count_, num_unit_);
thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);