forked from mindspore-Ecosystem/mindspore
fix: use MindDataset by column_names get data error in some situation
This commit is contained in:
parent
d004ef2234
commit
d4d236bcce
|
@ -165,12 +165,22 @@ Status MindRecordOp::Init() {
|
|||
|
||||
Status MindRecordOp::SetColumnsBlob() {
|
||||
columns_blob_ = shard_reader_->get_blob_fields().second;
|
||||
|
||||
// get the exactly blob fields by columns_to_load_
|
||||
std::vector<std::string> columns_blob_exact;
|
||||
for (auto &blob_field : columns_blob_) {
|
||||
for (auto &column : columns_to_load_) {
|
||||
if (column.compare(blob_field) == 0) {
|
||||
columns_blob_exact.push_back(blob_field);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
columns_blob_index_ = std::vector<int32_t>(columns_to_load_.size(), -1);
|
||||
int32_t iBlob = 0;
|
||||
for (uint32_t i = 0; i < columns_blob_.size(); ++i) {
|
||||
if (column_name_mapping_.count(columns_blob_[i])) {
|
||||
columns_blob_index_[column_name_mapping_[columns_blob_[i]]] = iBlob++;
|
||||
}
|
||||
for (auto &blob_exact : columns_blob_exact) {
|
||||
columns_blob_index_[column_name_mapping_[blob_exact]] = iBlob++;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -294,6 +294,10 @@ class ShardReader {
|
|||
/// \brief get number of classes
|
||||
int64_t GetNumClasses(const std::string &file_path, const std::string &category_field);
|
||||
|
||||
/// \brief get exactly blob fields data by indices
|
||||
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
|
||||
std::vector<uint32_t> &ordered_selected_columns_index);
|
||||
|
||||
protected:
|
||||
uint64_t header_size_; // header size
|
||||
uint64_t page_size_; // page size
|
||||
|
|
|
@ -790,6 +790,8 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
|
|||
n_consumer = kMinConsumerCount;
|
||||
}
|
||||
CheckNlp();
|
||||
|
||||
// dead code
|
||||
if (nlp_) {
|
||||
selected_columns_ = selected_columns;
|
||||
} else {
|
||||
|
@ -801,6 +803,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
|
|||
}
|
||||
}
|
||||
}
|
||||
selected_columns_ = selected_columns;
|
||||
|
||||
if (CheckColumnList(selected_columns_) == FAILED) {
|
||||
MS_LOG(ERROR) << "Illegal column list";
|
||||
|
@ -1060,6 +1063,36 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns(
|
||||
std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) {
|
||||
std::vector<uint8_t> exactly_blob_fields_bytes;
|
||||
|
||||
auto uint64_from_bytes = [&](int64_t pos) {
|
||||
uint64_t result = 0;
|
||||
for (uint64_t n = 0; n < kInt64Len; n++) {
|
||||
result = (result << 8) + blob_fields_bytes[pos + n];
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
// get the exactly blob fields
|
||||
uint32_t current_index = 0;
|
||||
uint64_t current_offset = 0;
|
||||
uint64_t data_len = uint64_from_bytes(current_offset);
|
||||
while (current_offset < blob_fields_bytes.size()) {
|
||||
if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(),
|
||||
[¤t_index](uint32_t &index) { return index == current_index; })) {
|
||||
exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset,
|
||||
blob_fields_bytes.begin() + current_offset + kInt64Len + data_len);
|
||||
}
|
||||
current_index++;
|
||||
current_offset += kInt64Len + data_len;
|
||||
data_len = uint64_from_bytes(current_offset);
|
||||
}
|
||||
|
||||
return exactly_blob_fields_bytes;
|
||||
}
|
||||
|
||||
TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
|
||||
// All tasks are done
|
||||
if (task_id >= static_cast<int>(tasks_.Size())) {
|
||||
|
@ -1077,6 +1110,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
|||
return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
|
||||
}
|
||||
const std::shared_ptr<Page> &page = ret.second;
|
||||
|
||||
// Pack image list
|
||||
std::vector<uint8_t> images(addr[1] - addr[0]);
|
||||
auto file_offset = header_size_ + page_size_ * (page->get_page_id()) + addr[0];
|
||||
|
@ -1096,10 +1130,42 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
|||
return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
|
||||
}
|
||||
|
||||
// extract the exactly blob bytes by selected columns
|
||||
std::vector<uint8_t> images_with_exact_columns;
|
||||
if (selected_columns_.size() == 0) {
|
||||
images_with_exact_columns = images;
|
||||
} else {
|
||||
auto blob_fields = get_blob_fields();
|
||||
|
||||
std::vector<uint32_t> ordered_selected_columns_index;
|
||||
uint32_t index = 0;
|
||||
for (auto &blob_field : blob_fields.second) {
|
||||
for (auto &field : selected_columns_) {
|
||||
if (field.compare(blob_field) == 0) {
|
||||
ordered_selected_columns_index.push_back(index);
|
||||
break;
|
||||
}
|
||||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
if (ordered_selected_columns_index.size() != 0) {
|
||||
// extract the images
|
||||
if (blob_fields.second.size() == 1) {
|
||||
if (ordered_selected_columns_index.size() == 1) {
|
||||
images_with_exact_columns = images;
|
||||
}
|
||||
} else {
|
||||
images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deliver batch data to output map
|
||||
std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
|
||||
if (nlp_) {
|
||||
json blob_fields = json::from_msgpack(images);
|
||||
// dead code
|
||||
json blob_fields = json::from_msgpack(images_with_exact_columns);
|
||||
|
||||
json merge;
|
||||
if (selected_columns_.size() > 0) {
|
||||
|
@ -1117,7 +1183,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
|||
}
|
||||
batch.emplace_back(std::vector<uint8_t>{}, std::move(merge));
|
||||
} else {
|
||||
batch.emplace_back(std::move(images), std::move(std::get<2>(task)));
|
||||
batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task)));
|
||||
}
|
||||
return std::make_pair(SUCCESS, std::move(batch));
|
||||
}
|
||||
|
|
|
@ -92,15 +92,25 @@ def populate_data(raw, blob, columns, blob_fields, schema):
|
|||
if raw:
|
||||
# remove dummy fileds
|
||||
raw = {k: v for k, v in raw.items() if k in schema}
|
||||
else:
|
||||
raw = {}
|
||||
if not blob_fields:
|
||||
return raw
|
||||
|
||||
# Get the order preserving sequence of columns in blob
|
||||
ordered_columns = []
|
||||
if columns:
|
||||
for blob_field in blob_fields:
|
||||
if blob_field in columns:
|
||||
ordered_columns.append(blob_field)
|
||||
else:
|
||||
ordered_columns = blob_fields
|
||||
|
||||
blob_bytes = bytes(blob)
|
||||
|
||||
def _render_raw(field, blob_data):
|
||||
data_type = schema[field]['type']
|
||||
data_shape = schema[field]['shape'] if 'shape' in schema[field] else []
|
||||
if columns and field not in columns:
|
||||
return
|
||||
if data_shape:
|
||||
try:
|
||||
raw[field] = np.reshape(np.frombuffer(blob_data, dtype=data_type), data_shape)
|
||||
|
@ -110,7 +120,9 @@ def populate_data(raw, blob, columns, blob_fields, schema):
|
|||
raw[field] = blob_data
|
||||
|
||||
if len(blob_fields) == 1:
|
||||
_render_raw(blob_fields[0], blob_bytes)
|
||||
if len(ordered_columns) == 1:
|
||||
_render_raw(blob_fields[0], blob_bytes)
|
||||
return raw
|
||||
return raw
|
||||
|
||||
def _int_from_bytes(xbytes: bytes) -> int:
|
||||
|
@ -125,6 +137,6 @@ def populate_data(raw, blob, columns, blob_fields, schema):
|
|||
start += 8
|
||||
return blob_bytes[start : start + n_bytes]
|
||||
|
||||
for i, blob_field in enumerate(blob_fields):
|
||||
for i, blob_field in enumerate(ordered_columns):
|
||||
_render_raw(blob_field, _blob_at_position(i))
|
||||
return raw
|
||||
|
|
|
@ -545,3 +545,597 @@ def inputs(vectors, maxlen=50):
|
|||
mask = [1]*length + [0]*(maxlen-length)
|
||||
segment = [0]*maxlen
|
||||
return input_, mask, segment
|
||||
|
||||
def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
|
||||
mindrecord_file_name = "test.mindrecord"
|
||||
data = [{"file_name": "001.jpg", "label": 4,
|
||||
"image1": bytes("image1 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image1 bytes def", encoding='UTF-8'),
|
||||
"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image1 bytes mno", encoding='UTF-8'),
|
||||
"target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
|
||||
{"file_name": "002.jpg", "label": 5,
|
||||
"image1": bytes("image2 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image2 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image2 bytes mno", encoding='UTF-8'),
|
||||
"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
|
||||
{"file_name": "003.jpg", "label": 6,
|
||||
"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"image1": bytes("image3 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image3 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image3 bytes mno", encoding='UTF-8'),
|
||||
"target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
|
||||
{"file_name": "004.jpg", "label": 7,
|
||||
"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"image1": bytes("image4 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image4 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image4 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image4 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image4 bytes mno", encoding='UTF-8'),
|
||||
"target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
|
||||
{"file_name": "005.jpg", "label": 8,
|
||||
"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"image1": bytes("image5 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image5 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image5 bytes mno", encoding='UTF-8'),
|
||||
"target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
|
||||
{"file_name": "006.jpg", "label": 9,
|
||||
"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
|
||||
"image1": bytes("image6 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image6 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image6 bytes mno", encoding='UTF-8'),
|
||||
"target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
|
||||
]
|
||||
|
||||
writer = FileWriter(mindrecord_file_name)
|
||||
schema = {"file_name": {"type": "string"},
|
||||
"image1": {"type": "bytes"},
|
||||
"image2": {"type": "bytes"},
|
||||
"source_sos_ids": {"type": "int64", "shape": [-1]},
|
||||
"source_sos_mask": {"type": "int64", "shape": [-1]},
|
||||
"image3": {"type": "bytes"},
|
||||
"image4": {"type": "bytes"},
|
||||
"image5": {"type": "bytes"},
|
||||
"target_sos_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_sos_mask": {"type": "int64", "shape": [-1]},
|
||||
"target_eos_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_eos_mask": {"type": "int64", "shape": [-1]},
|
||||
"label": {"type": "int32"}}
|
||||
writer.add_schema(schema, "data is so cool")
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
# change data value to list
|
||||
data_value_to_list = []
|
||||
for item in data:
|
||||
new_data = {}
|
||||
new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8)
|
||||
new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
|
||||
new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
|
||||
new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
|
||||
new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
|
||||
new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
|
||||
new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
|
||||
new_data['source_sos_ids'] = item["source_sos_ids"]
|
||||
new_data['source_sos_mask'] = item["source_sos_mask"]
|
||||
new_data['target_sos_ids'] = item["target_sos_ids"]
|
||||
new_data['target_sos_mask'] = item["target_sos_mask"]
|
||||
new_data['target_eos_ids'] = item["target_eos_ids"]
|
||||
new_data['target_eos_mask'] = item["target_eos_mask"]
|
||||
data_value_to_list.append(new_data)
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 13
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["source_sos_ids", "source_sos_mask", "target_sos_ids"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 3
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 1
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 4
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 3
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["target_sos_ids", "image4", "source_sos_ids"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 3
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 3
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["target_sos_ids", "image5", "image4", "image3", "source_sos_ids"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 5
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 1
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["target_eos_mask", "image5", "image2", "source_sos_mask", "label"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 5
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["label", "target_eos_mask", "image1", "target_eos_ids", "source_sos_mask",
|
||||
"image2", "image4", "image3", "source_sos_ids", "image5", "file_name"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 11
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
os.remove("{}".format(mindrecord_file_name))
|
||||
os.remove("{}.db".format(mindrecord_file_name))
|
||||
|
||||
def test_write_with_multi_bytes_and_MindDataset():
|
||||
mindrecord_file_name = "test.mindrecord"
|
||||
data = [{"file_name": "001.jpg", "label": 43,
|
||||
"image1": bytes("image1 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image1 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image1 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "002.jpg", "label": 91,
|
||||
"image1": bytes("image2 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image2 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image2 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "003.jpg", "label": 61,
|
||||
"image1": bytes("image3 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image3 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image3 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "004.jpg", "label": 29,
|
||||
"image1": bytes("image4 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image4 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image4 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image4 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image4 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "005.jpg", "label": 78,
|
||||
"image1": bytes("image5 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image5 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image5 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "006.jpg", "label": 37,
|
||||
"image1": bytes("image6 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image6 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image6 bytes mno", encoding='UTF-8')}
|
||||
]
|
||||
writer = FileWriter(mindrecord_file_name)
|
||||
schema = {"file_name": {"type": "string"},
|
||||
"image1": {"type": "bytes"},
|
||||
"image2": {"type": "bytes"},
|
||||
"image3": {"type": "bytes"},
|
||||
"label": {"type": "int32"},
|
||||
"image4": {"type": "bytes"},
|
||||
"image5": {"type": "bytes"}}
|
||||
writer.add_schema(schema, "data is so cool")
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
# change data value to list
|
||||
data_value_to_list = []
|
||||
for item in data:
|
||||
new_data = {}
|
||||
new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8)
|
||||
new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
|
||||
new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
|
||||
new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
|
||||
new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
|
||||
new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
|
||||
new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
|
||||
data_value_to_list.append(new_data)
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 7
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["image1", "image2", "image5"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 3
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["image2", "image4"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 2
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["image5", "image2"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 2
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["image5", "image2", "label"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 3
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["image4", "image5", "image2", "image3", "file_name"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 5
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
os.remove("{}".format(mindrecord_file_name))
|
||||
os.remove("{}.db".format(mindrecord_file_name))
|
||||
|
||||
def test_write_with_multi_array_and_MindDataset():
|
||||
mindrecord_file_name = "test.mindrecord"
|
||||
data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
|
||||
{"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
|
||||
{"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
|
||||
{"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
|
||||
{"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
|
||||
{"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
|
||||
]
|
||||
writer = FileWriter(mindrecord_file_name)
|
||||
schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
|
||||
"source_sos_mask": {"type": "int64", "shape": [-1]},
|
||||
"source_eos_ids": {"type": "int64", "shape": [-1]},
|
||||
"source_eos_mask": {"type": "int64", "shape": [-1]},
|
||||
"target_sos_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_sos_mask": {"type": "int64", "shape": [-1]},
|
||||
"target_eos_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_eos_mask": {"type": "int64", "shape": [-1]}}
|
||||
writer.add_schema(schema, "data is so cool")
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
# change data value to list - do none
|
||||
data_value_to_list = []
|
||||
for item in data:
|
||||
new_data = {}
|
||||
new_data['source_sos_ids'] = item["source_sos_ids"]
|
||||
new_data['source_sos_mask'] = item["source_sos_mask"]
|
||||
new_data['source_eos_ids'] = item["source_eos_ids"]
|
||||
new_data['source_eos_mask'] = item["source_eos_mask"]
|
||||
new_data['target_sos_ids'] = item["target_sos_ids"]
|
||||
new_data['target_sos_mask'] = item["target_sos_mask"]
|
||||
new_data['target_eos_ids'] = item["target_eos_ids"]
|
||||
new_data['target_eos_mask'] = item["target_eos_mask"]
|
||||
data_value_to_list.append(new_data)
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 8
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["source_eos_ids", "source_eos_mask",
|
||||
"target_sos_ids", "target_sos_mask",
|
||||
"target_eos_ids", "target_eos_mask"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 6
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["source_sos_ids",
|
||||
"target_sos_ids",
|
||||
"target_eos_mask"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 3
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["target_eos_mask",
|
||||
"source_eos_mask",
|
||||
"source_sos_mask"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 3
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 2
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["target_eos_ids"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 1
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
num_readers = 1
|
||||
data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
|
||||
columns_list=["target_eos_mask", "target_eos_ids",
|
||||
"target_sos_mask", "target_sos_ids",
|
||||
"source_eos_mask", "source_eos_ids",
|
||||
"source_sos_mask", "source_sos_ids"],
|
||||
num_parallel_workers=num_readers,
|
||||
shuffle=False)
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
assert len(item) == 8
|
||||
for field in item:
|
||||
if isinstance(item[field], np.ndarray):
|
||||
assert (item[field] == data_value_to_list[num_iter][field]).all()
|
||||
else:
|
||||
assert item[field] == data_value_to_list[num_iter][field]
|
||||
num_iter += 1
|
||||
assert num_iter == 6
|
||||
|
||||
os.remove("{}".format(mindrecord_file_name))
|
||||
os.remove("{}.db".format(mindrecord_file_name))
|
||||
|
|
|
@ -448,3 +448,456 @@ def test_cv_file_writer_no_raw():
|
|||
reader.close()
|
||||
os.remove(NLP_FILE_NAME)
|
||||
os.remove("{}.db".format(NLP_FILE_NAME))
|
||||
|
||||
def test_write_read_process_with_multi_bytes():
|
||||
mindrecord_file_name = "test.mindrecord"
|
||||
data = [{"file_name": "001.jpg", "label": 43,
|
||||
"image1": bytes("image1 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image1 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image1 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "002.jpg", "label": 91,
|
||||
"image1": bytes("image2 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image2 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image2 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "003.jpg", "label": 61,
|
||||
"image1": bytes("image3 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image3 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image3 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "004.jpg", "label": 29,
|
||||
"image1": bytes("image4 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image4 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image4 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image4 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image4 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "005.jpg", "label": 78,
|
||||
"image1": bytes("image5 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image5 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image5 bytes mno", encoding='UTF-8')},
|
||||
{"file_name": "006.jpg", "label": 37,
|
||||
"image1": bytes("image6 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image6 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image6 bytes mno", encoding='UTF-8')}
|
||||
]
|
||||
writer = FileWriter(mindrecord_file_name)
|
||||
schema = {"file_name": {"type": "string"},
|
||||
"image1": {"type": "bytes"},
|
||||
"image2": {"type": "bytes"},
|
||||
"image3": {"type": "bytes"},
|
||||
"label": {"type": "int32"},
|
||||
"image4": {"type": "bytes"},
|
||||
"image5": {"type": "bytes"}}
|
||||
writer.add_schema(schema, "data is so cool")
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
reader = FileReader(mindrecord_file_name)
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 7
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
reader2 = FileReader(file_name=mindrecord_file_name, columns=["image1", "image2", "image5"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader2.get_next()):
|
||||
assert len(x) == 3
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader2.close()
|
||||
|
||||
reader3 = FileReader(file_name=mindrecord_file_name, columns=["image2", "image4"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader3.get_next()):
|
||||
assert len(x) == 2
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader3.close()
|
||||
|
||||
reader4 = FileReader(file_name=mindrecord_file_name, columns=["image5", "image2"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader4.get_next()):
|
||||
assert len(x) == 2
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader4.close()
|
||||
|
||||
reader5 = FileReader(file_name=mindrecord_file_name, columns=["image5", "image2", "label"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader5.get_next()):
|
||||
assert len(x) == 3
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader5.close()
|
||||
|
||||
os.remove("{}".format(mindrecord_file_name))
|
||||
os.remove("{}.db".format(mindrecord_file_name))
|
||||
|
||||
def test_write_read_process_with_multi_array():
|
||||
mindrecord_file_name = "test.mindrecord"
|
||||
data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
|
||||
{"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
|
||||
{"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
|
||||
{"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
|
||||
{"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
|
||||
{"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64),
|
||||
"source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
|
||||
"target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
|
||||
]
|
||||
writer = FileWriter(mindrecord_file_name)
|
||||
schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
|
||||
"source_sos_mask": {"type": "int64", "shape": [-1]},
|
||||
"source_eos_ids": {"type": "int64", "shape": [-1]},
|
||||
"source_eos_mask": {"type": "int64", "shape": [-1]},
|
||||
"target_sos_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_sos_mask": {"type": "int64", "shape": [-1]},
|
||||
"target_eos_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_eos_mask": {"type": "int64", "shape": [-1]}}
|
||||
writer.add_schema(schema, "data is so cool")
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
reader = FileReader(mindrecord_file_name)
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 8
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
reader = FileReader(file_name=mindrecord_file_name, columns=["source_eos_ids", "source_eos_mask",
|
||||
"target_sos_ids", "target_sos_mask",
|
||||
"target_eos_ids", "target_eos_mask"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 6
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
reader = FileReader(file_name=mindrecord_file_name, columns=["source_sos_ids",
|
||||
"target_sos_ids",
|
||||
"target_eos_mask"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 3
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_mask",
|
||||
"source_eos_mask",
|
||||
"source_sos_mask"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 3
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_ids"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 1
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
os.remove("{}".format(mindrecord_file_name))
|
||||
os.remove("{}.db".format(mindrecord_file_name))
|
||||
|
||||
def test_write_read_process_with_multi_bytes_and_array():
|
||||
mindrecord_file_name = "test.mindrecord"
|
||||
data = [{"file_name": "001.jpg", "label": 4,
|
||||
"image1": bytes("image1 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image1 bytes def", encoding='UTF-8'),
|
||||
"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"image3": bytes("image1 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image1 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image1 bytes mno", encoding='UTF-8'),
|
||||
"target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
|
||||
{"file_name": "002.jpg", "label": 5,
|
||||
"image1": bytes("image2 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image2 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image2 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image2 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image2 bytes mno", encoding='UTF-8'),
|
||||
"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
|
||||
{"file_name": "003.jpg", "label": 6,
|
||||
"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"image1": bytes("image3 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image3 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image3 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image3 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image3 bytes mno", encoding='UTF-8'),
|
||||
"target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
|
||||
{"file_name": "004.jpg", "label": 7,
|
||||
"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"image1": bytes("image4 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image4 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image4 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image4 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image4 bytes mno", encoding='UTF-8'),
|
||||
"target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
|
||||
{"file_name": "005.jpg", "label": 8,
|
||||
"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
|
||||
"target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"image1": bytes("image5 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image5 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image5 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image5 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image5 bytes mno", encoding='UTF-8'),
|
||||
"target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
|
||||
{"file_name": "006.jpg", "label": 9,
|
||||
"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
|
||||
"source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
|
||||
"target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
|
||||
"image1": bytes("image6 bytes abc", encoding='UTF-8'),
|
||||
"image2": bytes("image6 bytes def", encoding='UTF-8'),
|
||||
"image3": bytes("image6 bytes ghi", encoding='UTF-8'),
|
||||
"image4": bytes("image6 bytes jkl", encoding='UTF-8'),
|
||||
"image5": bytes("image6 bytes mno", encoding='UTF-8'),
|
||||
"target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
|
||||
"target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
|
||||
"target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
|
||||
]
|
||||
|
||||
writer = FileWriter(mindrecord_file_name)
|
||||
schema = {"file_name": {"type": "string"},
|
||||
"image1": {"type": "bytes"},
|
||||
"image2": {"type": "bytes"},
|
||||
"source_sos_ids": {"type": "int64", "shape": [-1]},
|
||||
"source_sos_mask": {"type": "int64", "shape": [-1]},
|
||||
"image3": {"type": "bytes"},
|
||||
"image4": {"type": "bytes"},
|
||||
"image5": {"type": "bytes"},
|
||||
"target_sos_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_sos_mask": {"type": "int64", "shape": [-1]},
|
||||
"target_eos_ids": {"type": "int64", "shape": [-1]},
|
||||
"target_eos_mask": {"type": "int64", "shape": [-1]},
|
||||
"label": {"type": "int32"}}
|
||||
writer.add_schema(schema, "data is so cool")
|
||||
writer.write_raw_data(data)
|
||||
writer.commit()
|
||||
|
||||
reader = FileReader(mindrecord_file_name)
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 13
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
reader = FileReader(file_name=mindrecord_file_name, columns=["source_sos_ids", "source_sos_mask",
|
||||
"target_sos_ids"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 3
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
reader = FileReader(file_name=mindrecord_file_name, columns=["image2", "source_sos_mask",
|
||||
"image3", "target_sos_ids"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 4
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
reader = FileReader(file_name=mindrecord_file_name, columns=["target_sos_ids", "image4",
|
||||
"source_sos_ids"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 3
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
reader = FileReader(file_name=mindrecord_file_name, columns=["target_sos_ids", "image5",
|
||||
"image4", "image3", "source_sos_ids"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 5
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_mask", "image5", "image2",
|
||||
"source_sos_mask", "label"])
|
||||
count = 0
|
||||
for index, x in enumerate(reader.get_next()):
|
||||
assert len(x) == 5
|
||||
for field in x:
|
||||
if isinstance(x[field], np.ndarray):
|
||||
assert (x[field] == data[count][field]).all()
|
||||
else:
|
||||
assert x[field] == data[count][field]
|
||||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
assert count == 6
|
||||
reader.close()
|
||||
|
||||
os.remove("{}".format(mindrecord_file_name))
|
||||
os.remove("{}.db".format(mindrecord_file_name))
|
||||
|
|
Loading…
Reference in New Issue