!2398 Bug in Slice when multiple rows are used

Merge pull request !2398 from h.farahat/slice_bug
This commit is contained in:
mindspore-ci-bot 2020-06-22 09:23:44 +08:00 committed by Gitee
commit beb436f457
2 changed files with 19 additions and 2 deletions

View File

@ -33,8 +33,8 @@ Status SliceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
// if slice object was provided, indices should be empty. Generate indices from the slice object.
if (slice_.valid() && indices_.empty()) {
dsize_t len = input->shape()[0];
indices_ = slice_.Indices(len);
return input->Slice(output, indices_);
std::vector<dsize_t> indices = slice_.Indices(len);
return input->Slice(output, indices);
}
// if indices are not empty, slices should be invalid, use indices_ to slice

View File

@ -80,6 +80,22 @@ def test_slice_slice_obj_3s():
slice_compare([1, 2, 3, 4, 5], slice(2, 5, 3))
def test_slice_multiple_rows():
dataset = [[1, 2], [3, 4, 5], [1], [1, 2, 3, 4, 5, 6, 7]]
def gen():
for row in dataset:
yield (np.array(row),)
data = ds.GeneratorDataset(gen, column_names=["col"])
indexing = slice(0, 4)
data = data.map(operations=ops.Slice(indexing))
for i, d in enumerate(data):
array = np.array(dataset[i])
array = array[indexing]
np.testing.assert_array_equal(array, d[0])
def test_slice_slice_obj_3s_double():
slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 1))
slice_compare([1., 2., 3., 4., 5.], slice(0, 4, 1))
@ -217,3 +233,4 @@ if __name__ == "__main__":
test_slice_slice_obj_1s_str()
test_slice_slice_obj_neg_str()
test_slice_exceptions_str()
test_slice_multiple_rows()