forked from OSSInnovation/mindspore
!6839 Add functions that the split size can be -1 at the end of split ops
Merge pull request !6839 from liuwenhao/master
This commit is contained in:
commit
2ef51bdeae
|
@ -50,14 +50,6 @@ int DoSplit(float *in_data, float **out_data, const int *input_shape, int offset
|
||||||
split_which = i % num_split;
|
split_which = i % num_split;
|
||||||
split_times = i / num_split;
|
split_times = i / num_split;
|
||||||
int split_size = split_sizes[split_which];
|
int split_size = split_sizes[split_which];
|
||||||
// support split size is -1 in the end.
|
|
||||||
if (split_size == -1) {
|
|
||||||
int split_dim_i = input_shape[split_dim];
|
|
||||||
for (int j = 0; j < num_split - 1; ++j) {
|
|
||||||
split_dim_i -= split_sizes[j];
|
|
||||||
}
|
|
||||||
split_size = split_dim_i;
|
|
||||||
}
|
|
||||||
float *dst = out_data[split_which] + split_times * in_stride * split_size;
|
float *dst = out_data[split_which] + split_times * in_stride * split_size;
|
||||||
(void)memcpy(dst, src, split_size * in_stride_bytes);
|
(void)memcpy(dst, src, split_size * in_stride_bytes);
|
||||||
src += split_size * in_stride;
|
src += split_size * in_stride;
|
||||||
|
|
|
@ -58,6 +58,14 @@ int SplitBaseCPUKernel::ReSize() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (param->split_sizes_[param->num_split_ - 1] == -1) {
|
||||||
|
int split_shape_end = input_shape[param->split_dim_];
|
||||||
|
for (int i = 0; i < param->num_split_ - 1; i++) {
|
||||||
|
split_shape_end -= param->split_sizes_[i];
|
||||||
|
}
|
||||||
|
param->split_sizes_[param->num_split_ - 1] = split_shape_end;
|
||||||
|
}
|
||||||
|
|
||||||
num_unit_ = param->split_count_ * param->num_split_;
|
num_unit_ = param->split_count_ * param->num_split_;
|
||||||
thread_n_num_ = MSMIN(thread_count_, num_unit_);
|
thread_n_num_ = MSMIN(thread_count_, num_unit_);
|
||||||
thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
|
thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
|
||||||
|
|
Loading…
Reference in New Issue