forked from mindspore-Ecosystem/mindspore
!2398 Bug in Slice when multiple rows are used
Merge pull request !2398 from h.farahat/slice_bug
This commit is contained in:
commit
beb436f457
|
@ -33,8 +33,8 @@ Status SliceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
|
|||
// if slice object was provided, indices should be empty. Generate indices from the slice object.
|
||||
if (slice_.valid() && indices_.empty()) {
|
||||
dsize_t len = input->shape()[0];
|
||||
indices_ = slice_.Indices(len);
|
||||
return input->Slice(output, indices_);
|
||||
std::vector<dsize_t> indices = slice_.Indices(len);
|
||||
return input->Slice(output, indices);
|
||||
}
|
||||
|
||||
// if indices are not empty, slices should be invalid, use indices_ to slice
|
||||
|
|
|
@ -80,6 +80,22 @@ def test_slice_slice_obj_3s():
|
|||
slice_compare([1, 2, 3, 4, 5], slice(2, 5, 3))
|
||||
|
||||
|
||||
def test_slice_multiple_rows():
|
||||
dataset = [[1, 2], [3, 4, 5], [1], [1, 2, 3, 4, 5, 6, 7]]
|
||||
|
||||
def gen():
|
||||
for row in dataset:
|
||||
yield (np.array(row),)
|
||||
|
||||
data = ds.GeneratorDataset(gen, column_names=["col"])
|
||||
indexing = slice(0, 4)
|
||||
data = data.map(operations=ops.Slice(indexing))
|
||||
for i, d in enumerate(data):
|
||||
array = np.array(dataset[i])
|
||||
array = array[indexing]
|
||||
np.testing.assert_array_equal(array, d[0])
|
||||
|
||||
|
||||
def test_slice_slice_obj_3s_double():
|
||||
slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 1))
|
||||
slice_compare([1., 2., 3., 4., 5.], slice(0, 4, 1))
|
||||
|
@ -217,3 +233,4 @@ if __name__ == "__main__":
|
|||
test_slice_slice_obj_1s_str()
|
||||
test_slice_slice_obj_neg_str()
|
||||
test_slice_exceptions_str()
|
||||
test_slice_multiple_rows()
|
||||
|
|
Loading…
Reference in New Issue