forked from OSSInnovation/mindspore
change GetTensor into GetRow to avoid NullPtr
This commit is contained in:
parent
9ea10a0022
commit
7d2fe8c279
|
@ -91,11 +91,14 @@ void Sampler::Print(std::ostream &out, bool show_all) const {
|
||||||
Status Sampler::GetAllIdsThenReset(py::array *data) {
|
Status Sampler::GetAllIdsThenReset(py::array *data) {
|
||||||
std::unique_ptr<DataBuffer> db;
|
std::unique_ptr<DataBuffer> db;
|
||||||
std::shared_ptr<Tensor> sample_ids;
|
std::shared_ptr<Tensor> sample_ids;
|
||||||
|
TensorRow sample_row;
|
||||||
|
|
||||||
// A call to derived class to get sample ids wrapped inside a buffer
|
// A call to derived class to get sample ids wrapped inside a buffer
|
||||||
RETURN_IF_NOT_OK(GetNextSample(&db));
|
RETURN_IF_NOT_OK(GetNextSample(&db));
|
||||||
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
|
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
|
||||||
RETURN_IF_NOT_OK(db->GetTensor(&sample_ids, 0, 0));
|
RETURN_IF_NOT_OK(db->GetRow(0, &sample_row));
|
||||||
|
sample_ids = sample_row[0];
|
||||||
|
|
||||||
// check this buffer is not a ctrl buffer
|
// check this buffer is not a ctrl buffer
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received");
|
CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received");
|
||||||
{
|
{
|
||||||
|
|
|
@ -125,7 +125,6 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) {
|
||||||
EXPECT_TRUE(rc.IsOk());
|
EXPECT_TRUE(rc.IsOk());
|
||||||
row_count++;
|
row_count++;
|
||||||
}
|
}
|
||||||
MS_LOG(WARNING) <<"row count is: " << row_count;
|
|
||||||
ASSERT_EQ(row_count, 3); // Should be 3 rows fetched
|
ASSERT_EQ(row_count, 3); // Should be 3 rows fetched
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue