From 7d2fe8c27912daa24513c2e16a3404874580613b Mon Sep 17 00:00:00 2001 From: ms_yan <6576637+ms_yan@user.noreply.gitee.com> Date: Wed, 24 Jun 2020 11:55:17 +0800 Subject: [PATCH] change GetTensor into GetRow to avoid NullPtr --- .../dataset/engine/datasetops/source/sampler/sampler.cc | 5 ++++- tests/ut/cpp/dataset/zip_op_test.cc | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index 3f737c167cd..b3c595870f8 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -91,11 +91,14 @@ void Sampler::Print(std::ostream &out, bool show_all) const { Status Sampler::GetAllIdsThenReset(py::array *data) { std::unique_ptr db; std::shared_ptr sample_ids; + TensorRow sample_row; // A call to derived class to get sample ids wrapped inside a buffer RETURN_IF_NOT_OK(GetNextSample(&db)); // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch - RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0)); + RETURN_IF_NOT_OK(db->GetRow(0, &sample_row)); + sample_ids = sample_row[0]; + // check this buffer is not a ctrl buffer CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received"); { diff --git a/tests/ut/cpp/dataset/zip_op_test.cc b/tests/ut/cpp/dataset/zip_op_test.cc index 7885369c075..f8f8fe89db5 100644 --- a/tests/ut/cpp/dataset/zip_op_test.cc +++ b/tests/ut/cpp/dataset/zip_op_test.cc @@ -125,7 +125,6 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) { EXPECT_TRUE(rc.IsOk()); row_count++; } - MS_LOG(WARNING) <<"row count is: " << row_count; ASSERT_EQ(row_count, 3); // Should be 3 rows fetched }