!951 fix: MindDataset with columns_name parameter cause errors in some scenes

Merge pull request !951 from guozhijian/fix_read_by_columns
This commit is contained in:
mindspore-ci-bot 2020-05-08 04:51:03 +08:00 committed by Gitee
commit de7625777f
6 changed files with 1149 additions and 10 deletions

View File

@ -165,12 +165,22 @@ Status MindRecordOp::Init() {
Status MindRecordOp::SetColumnsBlob() {
columns_blob_ = shard_reader_->get_blob_fields().second;
// get the exactly blob fields by columns_to_load_
std::vector<std::string> columns_blob_exact;
for (auto &blob_field : columns_blob_) {
for (auto &column : columns_to_load_) {
if (column.compare(blob_field) == 0) {
columns_blob_exact.push_back(blob_field);
break;
}
}
}
columns_blob_index_ = std::vector<int32_t>(columns_to_load_.size(), -1);
int32_t iBlob = 0;
for (uint32_t i = 0; i < columns_blob_.size(); ++i) {
if (column_name_mapping_.count(columns_blob_[i])) {
columns_blob_index_[column_name_mapping_[columns_blob_[i]]] = iBlob++;
}
for (auto &blob_exact : columns_blob_exact) {
columns_blob_index_[column_name_mapping_[blob_exact]] = iBlob++;
}
return Status::OK();
}

View File

@ -294,6 +294,10 @@ class ShardReader {
/// \brief get number of classes
int64_t GetNumClasses(const std::string &file_path, const std::string &category_field);
/// \brief get exactly blob fields data by indices
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
std::vector<uint32_t> &ordered_selected_columns_index);
protected:
uint64_t header_size_; // header size
uint64_t page_size_; // page size

View File

@ -794,6 +794,8 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
n_consumer = kMinConsumerCount;
}
CheckNlp();
// dead code
if (nlp_) {
selected_columns_ = selected_columns;
} else {
@ -805,6 +807,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
}
}
}
selected_columns_ = selected_columns;
if (CheckColumnList(selected_columns_) == FAILED) {
MS_LOG(ERROR) << "Illegal column list";
@ -1064,6 +1067,36 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
return SUCCESS;
}
std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns(
std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) {
std::vector<uint8_t> exactly_blob_fields_bytes;
auto uint64_from_bytes = [&](int64_t pos) {
uint64_t result = 0;
for (uint64_t n = 0; n < kInt64Len; n++) {
result = (result << 8) + blob_fields_bytes[pos + n];
}
return result;
};
// get the exactly blob fields
uint32_t current_index = 0;
uint64_t current_offset = 0;
uint64_t data_len = uint64_from_bytes(current_offset);
while (current_offset < blob_fields_bytes.size()) {
if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(),
[&current_index](uint32_t &index) { return index == current_index; })) {
exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset,
blob_fields_bytes.begin() + current_offset + kInt64Len + data_len);
}
current_index++;
current_offset += kInt64Len + data_len;
data_len = uint64_from_bytes(current_offset);
}
return exactly_blob_fields_bytes;
}
TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
// All tasks are done
if (task_id >= static_cast<int>(tasks_.Size())) {
@ -1081,6 +1114,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
}
const std::shared_ptr<Page> &page = ret.second;
// Pack image list
std::vector<uint8_t> images(addr[1] - addr[0]);
auto file_offset = header_size_ + page_size_ * (page->get_page_id()) + addr[0];
@ -1100,10 +1134,42 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
}
// extract the exactly blob bytes by selected columns
std::vector<uint8_t> images_with_exact_columns;
if (selected_columns_.size() == 0) {
images_with_exact_columns = images;
} else {
auto blob_fields = get_blob_fields();
std::vector<uint32_t> ordered_selected_columns_index;
uint32_t index = 0;
for (auto &blob_field : blob_fields.second) {
for (auto &field : selected_columns_) {
if (field.compare(blob_field) == 0) {
ordered_selected_columns_index.push_back(index);
break;
}
}
index++;
}
if (ordered_selected_columns_index.size() != 0) {
// extract the images
if (blob_fields.second.size() == 1) {
if (ordered_selected_columns_index.size() == 1) {
images_with_exact_columns = images;
}
} else {
images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index);
}
}
}
// Deliver batch data to output map
std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
if (nlp_) {
json blob_fields = json::from_msgpack(images);
// dead code
json blob_fields = json::from_msgpack(images_with_exact_columns);
json merge;
if (selected_columns_.size() > 0) {
@ -1121,7 +1187,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
}
batch.emplace_back(std::vector<uint8_t>{}, std::move(merge));
} else {
batch.emplace_back(std::move(images), std::move(std::get<2>(task)));
batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task)));
}
return std::make_pair(SUCCESS, std::move(batch));
}

View File

@ -92,15 +92,25 @@ def populate_data(raw, blob, columns, blob_fields, schema):
if raw:
# remove dummy fileds
raw = {k: v for k, v in raw.items() if k in schema}
else:
raw = {}
if not blob_fields:
return raw
# Get the order preserving sequence of columns in blob
ordered_columns = []
if columns:
for blob_field in blob_fields:
if blob_field in columns:
ordered_columns.append(blob_field)
else:
ordered_columns = blob_fields
blob_bytes = bytes(blob)
def _render_raw(field, blob_data):
data_type = schema[field]['type']
data_shape = schema[field]['shape'] if 'shape' in schema[field] else []
if columns and field not in columns:
return
if data_shape:
try:
raw[field] = np.reshape(np.frombuffer(blob_data, dtype=data_type), data_shape)
@ -110,7 +120,9 @@ def populate_data(raw, blob, columns, blob_fields, schema):
raw[field] = blob_data
if len(blob_fields) == 1:
_render_raw(blob_fields[0], blob_bytes)
if len(ordered_columns) == 1:
_render_raw(blob_fields[0], blob_bytes)
return raw
return raw
def _int_from_bytes(xbytes: bytes) -> int:
@ -125,6 +137,6 @@ def populate_data(raw, blob, columns, blob_fields, schema):
start += 8
return blob_bytes[start : start + n_bytes]
for i, blob_field in enumerate(blob_fields):
for i, blob_field in enumerate(ordered_columns):
_render_raw(blob_field, _blob_at_position(i))
return raw

View File

@ -545,3 +545,597 @@ def inputs(vectors, maxlen=50):
mask = [1]*length + [0]*(maxlen-length)
segment = [0]*maxlen
return input_, mask, segment
def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
mindrecord_file_name = "test.mindrecord"
data = [{"file_name": "001.jpg", "label": 4,
"image1": bytes("image1 bytes abc", encoding='UTF-8'),
"image2": bytes("image1 bytes def", encoding='UTF-8'),
"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
"image5": bytes("image1 bytes mno", encoding='UTF-8'),
"target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
{"file_name": "002.jpg", "label": 5,
"image1": bytes("image2 bytes abc", encoding='UTF-8'),
"image2": bytes("image2 bytes def", encoding='UTF-8'),
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
"image5": bytes("image2 bytes mno", encoding='UTF-8'),
"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
{"file_name": "003.jpg", "label": 6,
"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"image1": bytes("image3 bytes abc", encoding='UTF-8'),
"image2": bytes("image3 bytes def", encoding='UTF-8'),
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
"image5": bytes("image3 bytes mno", encoding='UTF-8'),
"target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
{"file_name": "004.jpg", "label": 7,
"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"image1": bytes("image4 bytes abc", encoding='UTF-8'),
"image2": bytes("image4 bytes def", encoding='UTF-8'),
"image3": bytes("image4 bytes ghi", encoding='UTF-8'),
"image4": bytes("image4 bytes jkl", encoding='UTF-8'),
"image5": bytes("image4 bytes mno", encoding='UTF-8'),
"target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
{"file_name": "005.jpg", "label": 8,
"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
"image1": bytes("image5 bytes abc", encoding='UTF-8'),
"image2": bytes("image5 bytes def", encoding='UTF-8'),
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
"image5": bytes("image5 bytes mno", encoding='UTF-8'),
"target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
{"file_name": "006.jpg", "label": 9,
"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
"image1": bytes("image6 bytes abc", encoding='UTF-8'),
"image2": bytes("image6 bytes def", encoding='UTF-8'),
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
"image5": bytes("image6 bytes mno", encoding='UTF-8'),
"target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
]
writer = FileWriter(mindrecord_file_name)
schema = {"file_name": {"type": "string"},
"image1": {"type": "bytes"},
"image2": {"type": "bytes"},
"source_sos_ids": {"type": "int64", "shape": [-1]},
"source_sos_mask": {"type": "int64", "shape": [-1]},
"image3": {"type": "bytes"},
"image4": {"type": "bytes"},
"image5": {"type": "bytes"},
"target_sos_ids": {"type": "int64", "shape": [-1]},
"target_sos_mask": {"type": "int64", "shape": [-1]},
"target_eos_ids": {"type": "int64", "shape": [-1]},
"target_eos_mask": {"type": "int64", "shape": [-1]},
"label": {"type": "int32"}}
writer.add_schema(schema, "data is so cool")
writer.write_raw_data(data)
writer.commit()
# change data value to list
data_value_to_list = []
for item in data:
new_data = {}
new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8)
new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
new_data['source_sos_ids'] = item["source_sos_ids"]
new_data['source_sos_mask'] = item["source_sos_mask"]
new_data['target_sos_ids'] = item["target_sos_ids"]
new_data['target_sos_mask'] = item["target_sos_mask"]
new_data['target_eos_ids'] = item["target_eos_ids"]
new_data['target_eos_mask'] = item["target_eos_mask"]
data_value_to_list.append(new_data)
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 13
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["source_sos_ids", "source_sos_mask", "target_sos_ids"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 3
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data[num_iter][field]).all()
else:
assert item[field] == data[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 1
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 4
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 3
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["target_sos_ids", "image4", "source_sos_ids"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 3
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 3
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["target_sos_ids", "image5", "image4", "image3", "source_sos_ids"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 5
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 1
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["target_eos_mask", "image5", "image2", "source_sos_mask", "label"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 5
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["label", "target_eos_mask", "image1", "target_eos_ids", "source_sos_mask",
"image2", "image4", "image3", "source_sos_ids", "image5", "file_name"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 11
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))
def test_write_with_multi_bytes_and_MindDataset():
mindrecord_file_name = "test.mindrecord"
data = [{"file_name": "001.jpg", "label": 43,
"image1": bytes("image1 bytes abc", encoding='UTF-8'),
"image2": bytes("image1 bytes def", encoding='UTF-8'),
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
"image5": bytes("image1 bytes mno", encoding='UTF-8')},
{"file_name": "002.jpg", "label": 91,
"image1": bytes("image2 bytes abc", encoding='UTF-8'),
"image2": bytes("image2 bytes def", encoding='UTF-8'),
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
"image5": bytes("image2 bytes mno", encoding='UTF-8')},
{"file_name": "003.jpg", "label": 61,
"image1": bytes("image3 bytes abc", encoding='UTF-8'),
"image2": bytes("image3 bytes def", encoding='UTF-8'),
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
"image5": bytes("image3 bytes mno", encoding='UTF-8')},
{"file_name": "004.jpg", "label": 29,
"image1": bytes("image4 bytes abc", encoding='UTF-8'),
"image2": bytes("image4 bytes def", encoding='UTF-8'),
"image3": bytes("image4 bytes ghi", encoding='UTF-8'),
"image4": bytes("image4 bytes jkl", encoding='UTF-8'),
"image5": bytes("image4 bytes mno", encoding='UTF-8')},
{"file_name": "005.jpg", "label": 78,
"image1": bytes("image5 bytes abc", encoding='UTF-8'),
"image2": bytes("image5 bytes def", encoding='UTF-8'),
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
"image5": bytes("image5 bytes mno", encoding='UTF-8')},
{"file_name": "006.jpg", "label": 37,
"image1": bytes("image6 bytes abc", encoding='UTF-8'),
"image2": bytes("image6 bytes def", encoding='UTF-8'),
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
"image5": bytes("image6 bytes mno", encoding='UTF-8')}
]
writer = FileWriter(mindrecord_file_name)
schema = {"file_name": {"type": "string"},
"image1": {"type": "bytes"},
"image2": {"type": "bytes"},
"image3": {"type": "bytes"},
"label": {"type": "int32"},
"image4": {"type": "bytes"},
"image5": {"type": "bytes"}}
writer.add_schema(schema, "data is so cool")
writer.write_raw_data(data)
writer.commit()
# change data value to list
data_value_to_list = []
for item in data:
new_data = {}
new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8)
new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
data_value_to_list.append(new_data)
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 7
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["image1", "image2", "image5"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 3
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["image2", "image4"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 2
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["image5", "image2"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 2
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["image5", "image2", "label"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 3
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["image4", "image5", "image2", "image3", "file_name"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 5
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))
def test_write_with_multi_array_and_MindDataset():
mindrecord_file_name = "test.mindrecord"
data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
{"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
{"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
{"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
{"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
{"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
]
writer = FileWriter(mindrecord_file_name)
schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
"source_sos_mask": {"type": "int64", "shape": [-1]},
"source_eos_ids": {"type": "int64", "shape": [-1]},
"source_eos_mask": {"type": "int64", "shape": [-1]},
"target_sos_ids": {"type": "int64", "shape": [-1]},
"target_sos_mask": {"type": "int64", "shape": [-1]},
"target_eos_ids": {"type": "int64", "shape": [-1]},
"target_eos_mask": {"type": "int64", "shape": [-1]}}
writer.add_schema(schema, "data is so cool")
writer.write_raw_data(data)
writer.commit()
# change data value to list - do none
data_value_to_list = []
for item in data:
new_data = {}
new_data['source_sos_ids'] = item["source_sos_ids"]
new_data['source_sos_mask'] = item["source_sos_mask"]
new_data['source_eos_ids'] = item["source_eos_ids"]
new_data['source_eos_mask'] = item["source_eos_mask"]
new_data['target_sos_ids'] = item["target_sos_ids"]
new_data['target_sos_mask'] = item["target_sos_mask"]
new_data['target_eos_ids'] = item["target_eos_ids"]
new_data['target_eos_mask'] = item["target_eos_mask"]
data_value_to_list.append(new_data)
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 8
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["source_eos_ids", "source_eos_mask",
"target_sos_ids", "target_sos_mask",
"target_eos_ids", "target_eos_mask"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 6
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["source_sos_ids",
"target_sos_ids",
"target_eos_mask"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 3
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["target_eos_mask",
"source_eos_mask",
"source_sos_mask"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 3
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 2
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["target_eos_ids"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 1
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
num_readers = 1
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
columns_list=["target_eos_mask", "target_eos_ids",
"target_sos_mask", "target_sos_ids",
"source_eos_mask", "source_eos_ids",
"source_sos_mask", "source_sos_ids"],
num_parallel_workers=num_readers,
shuffle=False)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
assert len(item) == 8
for field in item:
if isinstance(item[field], np.ndarray):
assert (item[field] == data_value_to_list[num_iter][field]).all()
else:
assert item[field] == data_value_to_list[num_iter][field]
num_iter += 1
assert num_iter == 6
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))

View File

@ -448,3 +448,456 @@ def test_cv_file_writer_no_raw():
reader.close()
os.remove(NLP_FILE_NAME)
os.remove("{}.db".format(NLP_FILE_NAME))
def test_write_read_process_with_multi_bytes():
mindrecord_file_name = "test.mindrecord"
data = [{"file_name": "001.jpg", "label": 43,
"image1": bytes("image1 bytes abc", encoding='UTF-8'),
"image2": bytes("image1 bytes def", encoding='UTF-8'),
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
"image5": bytes("image1 bytes mno", encoding='UTF-8')},
{"file_name": "002.jpg", "label": 91,
"image1": bytes("image2 bytes abc", encoding='UTF-8'),
"image2": bytes("image2 bytes def", encoding='UTF-8'),
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
"image5": bytes("image2 bytes mno", encoding='UTF-8')},
{"file_name": "003.jpg", "label": 61,
"image1": bytes("image3 bytes abc", encoding='UTF-8'),
"image2": bytes("image3 bytes def", encoding='UTF-8'),
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
"image5": bytes("image3 bytes mno", encoding='UTF-8')},
{"file_name": "004.jpg", "label": 29,
"image1": bytes("image4 bytes abc", encoding='UTF-8'),
"image2": bytes("image4 bytes def", encoding='UTF-8'),
"image3": bytes("image4 bytes ghi", encoding='UTF-8'),
"image4": bytes("image4 bytes jkl", encoding='UTF-8'),
"image5": bytes("image4 bytes mno", encoding='UTF-8')},
{"file_name": "005.jpg", "label": 78,
"image1": bytes("image5 bytes abc", encoding='UTF-8'),
"image2": bytes("image5 bytes def", encoding='UTF-8'),
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
"image5": bytes("image5 bytes mno", encoding='UTF-8')},
{"file_name": "006.jpg", "label": 37,
"image1": bytes("image6 bytes abc", encoding='UTF-8'),
"image2": bytes("image6 bytes def", encoding='UTF-8'),
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
"image5": bytes("image6 bytes mno", encoding='UTF-8')}
]
writer = FileWriter(mindrecord_file_name)
schema = {"file_name": {"type": "string"},
"image1": {"type": "bytes"},
"image2": {"type": "bytes"},
"image3": {"type": "bytes"},
"label": {"type": "int32"},
"image4": {"type": "bytes"},
"image5": {"type": "bytes"}}
writer.add_schema(schema, "data is so cool")
writer.write_raw_data(data)
writer.commit()
reader = FileReader(mindrecord_file_name)
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 7
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
reader2 = FileReader(file_name=mindrecord_file_name, columns=["image1", "image2", "image5"])
count = 0
for index, x in enumerate(reader2.get_next()):
assert len(x) == 3
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader2.close()
reader3 = FileReader(file_name=mindrecord_file_name, columns=["image2", "image4"])
count = 0
for index, x in enumerate(reader3.get_next()):
assert len(x) == 2
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader3.close()
reader4 = FileReader(file_name=mindrecord_file_name, columns=["image5", "image2"])
count = 0
for index, x in enumerate(reader4.get_next()):
assert len(x) == 2
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader4.close()
reader5 = FileReader(file_name=mindrecord_file_name, columns=["image5", "image2", "label"])
count = 0
for index, x in enumerate(reader5.get_next()):
assert len(x) == 3
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader5.close()
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))
def test_write_read_process_with_multi_array():
mindrecord_file_name = "test.mindrecord"
data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
{"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
{"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
{"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
{"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
{"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64),
"source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
"target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
]
writer = FileWriter(mindrecord_file_name)
schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
"source_sos_mask": {"type": "int64", "shape": [-1]},
"source_eos_ids": {"type": "int64", "shape": [-1]},
"source_eos_mask": {"type": "int64", "shape": [-1]},
"target_sos_ids": {"type": "int64", "shape": [-1]},
"target_sos_mask": {"type": "int64", "shape": [-1]},
"target_eos_ids": {"type": "int64", "shape": [-1]},
"target_eos_mask": {"type": "int64", "shape": [-1]}}
writer.add_schema(schema, "data is so cool")
writer.write_raw_data(data)
writer.commit()
reader = FileReader(mindrecord_file_name)
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 8
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
reader = FileReader(file_name=mindrecord_file_name, columns=["source_eos_ids", "source_eos_mask",
"target_sos_ids", "target_sos_mask",
"target_eos_ids", "target_eos_mask"])
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 6
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
reader = FileReader(file_name=mindrecord_file_name, columns=["source_sos_ids",
"target_sos_ids",
"target_eos_mask"])
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 3
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_mask",
"source_eos_mask",
"source_sos_mask"])
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 3
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_ids"])
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 1
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))
def test_write_read_process_with_multi_bytes_and_array():
mindrecord_file_name = "test.mindrecord"
data = [{"file_name": "001.jpg", "label": 4,
"image1": bytes("image1 bytes abc", encoding='UTF-8'),
"image2": bytes("image1 bytes def", encoding='UTF-8'),
"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
"image5": bytes("image1 bytes mno", encoding='UTF-8'),
"target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
{"file_name": "002.jpg", "label": 5,
"image1": bytes("image2 bytes abc", encoding='UTF-8'),
"image2": bytes("image2 bytes def", encoding='UTF-8'),
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
"image5": bytes("image2 bytes mno", encoding='UTF-8'),
"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
{"file_name": "003.jpg", "label": 6,
"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"image1": bytes("image3 bytes abc", encoding='UTF-8'),
"image2": bytes("image3 bytes def", encoding='UTF-8'),
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
"image5": bytes("image3 bytes mno", encoding='UTF-8'),
"target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
{"file_name": "004.jpg", "label": 7,
"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"image1": bytes("image4 bytes abc", encoding='UTF-8'),
"image2": bytes("image4 bytes def", encoding='UTF-8'),
"image3": bytes("image4 bytes ghi", encoding='UTF-8'),
"image4": bytes("image4 bytes jkl", encoding='UTF-8'),
"image5": bytes("image4 bytes mno", encoding='UTF-8'),
"target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
{"file_name": "005.jpg", "label": 8,
"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
"target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
"image1": bytes("image5 bytes abc", encoding='UTF-8'),
"image2": bytes("image5 bytes def", encoding='UTF-8'),
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
"image5": bytes("image5 bytes mno", encoding='UTF-8'),
"target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
{"file_name": "006.jpg", "label": 9,
"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
"source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
"target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
"image1": bytes("image6 bytes abc", encoding='UTF-8'),
"image2": bytes("image6 bytes def", encoding='UTF-8'),
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
"image5": bytes("image6 bytes mno", encoding='UTF-8'),
"target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
"target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
"target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
]
writer = FileWriter(mindrecord_file_name)
schema = {"file_name": {"type": "string"},
"image1": {"type": "bytes"},
"image2": {"type": "bytes"},
"source_sos_ids": {"type": "int64", "shape": [-1]},
"source_sos_mask": {"type": "int64", "shape": [-1]},
"image3": {"type": "bytes"},
"image4": {"type": "bytes"},
"image5": {"type": "bytes"},
"target_sos_ids": {"type": "int64", "shape": [-1]},
"target_sos_mask": {"type": "int64", "shape": [-1]},
"target_eos_ids": {"type": "int64", "shape": [-1]},
"target_eos_mask": {"type": "int64", "shape": [-1]},
"label": {"type": "int32"}}
writer.add_schema(schema, "data is so cool")
writer.write_raw_data(data)
writer.commit()
reader = FileReader(mindrecord_file_name)
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 13
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
reader = FileReader(file_name=mindrecord_file_name, columns=["source_sos_ids", "source_sos_mask",
"target_sos_ids"])
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 3
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
reader = FileReader(file_name=mindrecord_file_name, columns=["image2", "source_sos_mask",
"image3", "target_sos_ids"])
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 4
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
reader = FileReader(file_name=mindrecord_file_name, columns=["target_sos_ids", "image4",
"source_sos_ids"])
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 3
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
reader = FileReader(file_name=mindrecord_file_name, columns=["target_sos_ids", "image5",
"image4", "image3", "source_sos_ids"])
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 5
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_mask", "image5", "image2",
"source_sos_mask", "label"])
count = 0
for index, x in enumerate(reader.get_next()):
assert len(x) == 5
for field in x:
if isinstance(x[field], np.ndarray):
assert (x[field] == data[count][field]).all()
else:
assert x[field] == data[count][field]
count = count + 1
logger.info("#item{}: {}".format(index, x))
assert count == 6
reader.close()
os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name))