diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
index 978d480aa5e..35b03726cda 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc
@@ -227,7 +227,9 @@ Status GeneratorOp::operator()() {
 
         // Restore exception to python
         e.restore();
-        if (num_rows_sampled != -1 && num_rows_sampled != generator_counter_) {
+
+        // Check whether the number of samples is sufficient only when the first epoch
+        if (num_rows_sampled != -1 && num_rows_sampled != generator_counter_ && op_current_epochs_ == 0) {
           if (generator_counter_ == 0) {
             std::string msg =
               "Unable to fetch data from GeneratorDataset, try iterate the source function of GeneratorDataset or check"
diff --git a/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py b/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py
index 23f78679f55..badedaad6f5 100644
--- a/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py
+++ b/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py
@@ -678,6 +678,10 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
         if hasattr(self.source, "__len__"):
             self.source_len = len(self.source)
 
+            # if user defined sampler, update the self.source_len
+            if isinstance(self.sampler, samplers.Sampler) or hasattr(self.sampler, "__iter__"):
+                self.source_len = len(list(sampler))
+
         self.max_rowsize = max_rowsize
         self.sample_fn = None
 
diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py
index 3214593b385..11ecb870786 100644
--- a/tests/ut/python/dataset/test_sampler.py
+++ b/tests/ut/python/dataset/test_sampler.py
@@ -390,6 +390,173 @@ def test_sampler_list():
                  msg="Type of indices element must be int, but got list[0]: [1 2], type: <class 'numpy.ndarray'>.")
 
 
+def check_result(expected, result):
+    for index, item in enumerate(result):
+        assert str(expected[index][0]) == item
+
+
+def test_sampler_when_less_and_larger_index_ids():
+    """
+    Feature: Sampler op
+    Description: Test sampler with less and larger index ids
+    Expectation: success
+    """
+
+    # sampler with less index ids
+    class MySampler():
+        def __iter__(self):
+            for i in range(0, 10, 2):
+                yield i
+
+    np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']
+
+    dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler())
+    count = 0
+    expected_data = []
+    for data in dataset.create_tuple_iterator(num_epochs=1):
+        count += 1
+        expected_data.append(data)
+    assert count == 5
+    check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])
+
+    epochs = 3
+    ds_iter = dataset.create_tuple_iterator(num_epochs=epochs)
+    for _ in range(epochs):
+        count = 0
+        expected_data = []
+        for data in ds_iter:
+            count += 1
+            expected_data.append(data)
+        assert count == 5
+        check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])
+
+    # sampler with larger index ids
+    index = [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8]
+    class MySampler2():
+        def __iter__(self):
+            for i in index:
+                yield i
+
+    dataset2 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler2())
+    count = 0
+    expected_data = []
+    for data in dataset2.create_tuple_iterator(num_epochs=1):
+        count += 1
+        expected_data.append(data)
+    assert count == 16
+    check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])
+
+    epochs = 3
+    ds_iter2 = dataset2.create_tuple_iterator(num_epochs=epochs)
+    for _ in range(epochs):
+        count = 0
+        expected_data = []
+        for data in ds_iter2:
+            count += 1
+            expected_data.append(data)
+        assert count == 16
+        check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])
+
+
+def test_sampler_with_getitem_method():
+    """
+    Feature: Sampler op
+    Description: Test sampler with __getitem__ method
+    Expectation: success
+    """
+
+    # sampler with equal index ids
+    class MySampler():
+        def __init__(self):
+            self.index_ids = [3, 8, 7, 2, 0, 9, 11, 4, 5, 1, 6, 10]
+        def __getitem__(self, index):
+            return self.index_ids[index]
+        def __len__(self):
+            return len(self.index_ids)
+
+    np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']
+
+    dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler())
+    count = 0
+    expected_data = []
+    for data in dataset.create_tuple_iterator(num_epochs=1):
+        count += 1
+        expected_data.append(data)
+    assert count == 12
+    check_result(expected_data, ['d', 'i', 'h', 'c', 'a', 'j', 'l', 'e', 'f', 'b', 'g', 'k'])
+
+    epochs = 3
+    ds_iter = dataset.create_tuple_iterator(num_epochs=epochs)
+    for _ in range(epochs):
+        count = 0
+        expected_data = []
+        for data in ds_iter:
+            count += 1
+            expected_data.append(data)
+        assert count == 12
+        check_result(expected_data, ['d', 'i', 'h', 'c', 'a', 'j', 'l', 'e', 'f', 'b', 'g', 'k'])
+
+    # sampler with less index ids
+    class MySampler2():
+        def __init__(self):
+            self.index_ids = [0, 2, 4, 6, 8]
+        def __getitem__(self, index):
+            return self.index_ids[index]
+        def __len__(self):
+            return len(self.index_ids)
+
+    np_data = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l']
+
+    dataset2 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler2())
+    count = 0
+    expected_data = []
+    for data in dataset2.create_tuple_iterator(num_epochs=1):
+        count += 1
+        expected_data.append(data)
+    assert count == 5
+    check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])
+
+    epochs = 3
+    ds_iter2 = dataset2.create_tuple_iterator(num_epochs=epochs)
+    for _ in range(epochs):
+        count = 0
+        expected_data = []
+        for data in ds_iter2:
+            count += 1
+            expected_data.append(data)
+        assert count == 5
+        check_result(expected_data, ['a', 'c', 'e', 'g', 'i'])
+
+    # sampler with larger index ids
+    class MySampler3():
+        def __init__(self):
+            self.index_ids = [3, 4, 3, 2, 0, 11, 5, 5, 5, 9, 1, 11, 11, 11, 11, 8]
+        def __getitem__(self, index):
+            return self.index_ids[index]
+        def __len__(self):
+            return len(self.index_ids)
+
+    dataset3 = ds.NumpySlicesDataset(np_data, column_names=["data"], sampler=MySampler3())
+    count = 0
+    expected_data = []
+    for data in dataset3.create_tuple_iterator(num_epochs=1):
+        count += 1
+        expected_data.append(data)
+    assert count == 16
+    check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])
+
+    epochs = 3
+    ds_iter3 = dataset3.create_tuple_iterator(num_epochs=epochs)
+    for _ in range(epochs):
+        count = 0
+        expected_data = []
+        for data in ds_iter3:
+            count += 1
+            expected_data.append(data)
+        assert count == 16
+        check_result(expected_data, ['d', 'e', 'd', 'c', 'a', 'l', 'f', 'f', 'f', 'j', 'b', 'l', 'l', 'l', 'l', 'i'])
+
+
 if __name__ == '__main__':
     test_sequential_sampler(True)
     test_random_sampler(True)
@@ -402,3 +569,5 @@ if __name__ == '__main__':
     test_add_sampler_invalid_input()
     test_distributed_sampler_invalid_offset()
     test_sampler_list()
+    test_sampler_when_less_and_larger_index_ids()
+    test_sampler_with_getitem_method()