forked from mindspore-Ecosystem/mindspore
fix tfrecord_op core dump
add test case fix ci fix ci round 2 address review cmts
This commit is contained in:
parent
9343b4f5c3
commit
ea97197311
|
@ -27,18 +27,14 @@
|
|||
|
||||
#include "proto/example.pb.h"
|
||||
#include "./securec.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/engine/connector.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/io_block.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/jagged_connector.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
|
@ -387,14 +383,14 @@ Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TFReaderOp::NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
bool TFReaderOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count) {
|
||||
*start_offset = 0;
|
||||
*end_offset = 0;
|
||||
bool push = false;
|
||||
int64_t start_index = device_id_ * num_rows_per_shard_;
|
||||
if (device_id_ + 1 < 0) {
|
||||
MS_LOG(ERROR) << "Device id is invalid";
|
||||
MS_LOG(ERROR) << "Device id is invalid.";
|
||||
return false;
|
||||
}
|
||||
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
|
||||
|
@ -448,7 +444,7 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) {
|
|||
} else {
|
||||
// Do an index lookup using that key to get the filename.
|
||||
std::string file_name = (*filename_index_)[*it];
|
||||
if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) {
|
||||
if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) {
|
||||
auto ioBlock = std::make_unique<FilenameBlock>(*it, start_offset, end_offset, IOBlock::kDeIoBlockNone);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
|
||||
MS_LOG(DEBUG) << "File name " << *it << " start offset " << start_offset << " end_offset " << end_offset;
|
||||
|
@ -496,7 +492,7 @@ Status TFReaderOp::FillIOBlockNoShuffle() {
|
|||
}
|
||||
} else {
|
||||
std::string file_name = it.value();
|
||||
if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) {
|
||||
if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) {
|
||||
auto ioBlock = std::make_unique<FilenameBlock>(it.key(), start_offset, end_offset, IOBlock::kDeIoBlockNone);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
|
||||
queue_index = (queue_index + 1) % num_workers_;
|
||||
|
@ -711,7 +707,7 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table
|
|||
// reinitializes itself so that it can be executed again, as if it was just created.
|
||||
Status TFReaderOp::Reset() {
|
||||
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
|
||||
// start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true
|
||||
// start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true
|
||||
load_jagged_connector_ = true;
|
||||
|
||||
{
|
||||
|
@ -767,6 +763,14 @@ Status TFReaderOp::LoadBytesList(const ColDescriptor ¤t_col, const dataeng
|
|||
new_pad_size *= cur_shape[i];
|
||||
}
|
||||
pad_size = new_pad_size;
|
||||
} else {
|
||||
if (cur_shape.known() && cur_shape.NumOfElements() != max_size) {
|
||||
std::string err_msg = "Shape in schema's column '" + current_col.name() + "' is incorrect." +
|
||||
"\nshape received: " + cur_shape.ToString() +
|
||||
"\ntotal elements in shape received: " + std::to_string(cur_shape.NumOfElements()) +
|
||||
"\nexpected total elements in shape: " + std::to_string(max_size);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -387,7 +387,7 @@ class TFReaderOp : public ParallelOp {
|
|||
// @param end_file - If file contains the end sample of data.
|
||||
// @param pre_count - Total rows of previous files.
|
||||
// @return Status - the error code returned.
|
||||
bool NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count);
|
||||
|
||||
// Caculate number of rows in each shard.
|
||||
|
|
|
@ -491,3 +491,17 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetExeception2) {
|
|||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestIncorrectTFSchemaObject) {
|
||||
std::string path = datasets_root_path_ + "/test_tf_file_3_images2/train-0000-of-0001.data";
|
||||
std::shared_ptr<api::SchemaObj> schema = api::Schema();
|
||||
schema->add_column("image", "uint8", {1});
|
||||
schema->add_column("label", "int64", {1});
|
||||
std::shared_ptr<api::Dataset> ds = api::TFRecord({path}, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
auto itr = ds->CreateIterator();
|
||||
EXPECT_NE(itr, nullptr);
|
||||
TensorMap mp;
|
||||
// this will fail due to the incorrect schema used
|
||||
EXPECT_FALSE(itr->GetNextRow(&mp));
|
||||
}
|
||||
|
|
|
@ -294,6 +294,24 @@ def test_tfrecord_invalid_files():
|
|||
assert nonexistent_file in str(info.value)
|
||||
|
||||
|
||||
def test_tf_wrong_schema():
|
||||
logger.info("test_tf_wrong_schema")
|
||||
files = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data"]
|
||||
schema = ds.Schema()
|
||||
schema.add_column('image', de_type=mstype.uint8, shape=[1])
|
||||
schema.add_column('label', de_type=mstype.int64, shape=[1])
|
||||
data1 = ds.TFRecordDataset(files, schema, shuffle=False)
|
||||
exception_occurred = False
|
||||
try:
|
||||
for _ in data1:
|
||||
pass
|
||||
except RuntimeError as e:
|
||||
exception_occurred = True
|
||||
assert "Shape in schema's column 'image' is incorrect" in str(e)
|
||||
|
||||
assert exception_occurred, "test_tf_wrong_schema failed."
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_tfrecord_shape()
|
||||
test_tfrecord_read_all_dataset()
|
||||
|
@ -312,3 +330,4 @@ if __name__ == '__main__':
|
|||
test_tfrecord_no_schema_columns_list()
|
||||
test_tfrecord_schema_columns_list()
|
||||
test_tfrecord_invalid_files()
|
||||
test_tf_wrong_schema()
|
||||
|
|
Loading…
Reference in New Issue