From d4d236bccef8d29397514ebac8d64374354c69ac Mon Sep 17 00:00:00 2001
From: jonyguo
Date: Wed, 6 May 2020 19:50:04 +0800
Subject: [PATCH] fix: data error when reading a MindDataset by column_names in
 some situations

---
 .../engine/datasetops/source/mindrecord_op.cc |  18 +-
 .../ccsrc/mindrecord/include/shard_reader.h   |   4 +
 mindspore/ccsrc/mindrecord/io/shard_reader.cc |  70 ++-
 mindspore/mindrecord/shardutils.py            |  20 +-
 tests/ut/python/dataset/test_minddataset.py   | 594 ++++++++++++++++++
 .../python/mindrecord/test_mindrecord_base.py | 453 +++++++++++++
 6 files changed, 1149 insertions(+), 10 deletions(-)

diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
index 171ad49fa7c..9458ca63079 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc
@@ -165,12 +165,22 @@ Status MindRecordOp::Init() {
 
 Status MindRecordOp::SetColumnsBlob() {
   columns_blob_ = shard_reader_->get_blob_fields().second;
+
+  // get the exact blob fields selected by columns_to_load_
+  std::vector<std::string> columns_blob_exact;
+  for (auto &blob_field : columns_blob_) {
+    for (auto &column : columns_to_load_) {
+      if (column.compare(blob_field) == 0) {
+        columns_blob_exact.push_back(blob_field);
+        break;
+      }
+    }
+  }
+
   columns_blob_index_ = std::vector<int32_t>(columns_to_load_.size(), -1);
   int32_t iBlob = 0;
-  for (uint32_t i = 0; i < columns_blob_.size(); ++i) {
-    if (column_name_mapping_.count(columns_blob_[i])) {
-      columns_blob_index_[column_name_mapping_[columns_blob_[i]]] = iBlob++;
-    }
+  for (auto &blob_exact : columns_blob_exact) {
+    columns_blob_index_[column_name_mapping_[blob_exact]] = iBlob++;
   }
   return Status::OK();
 }
diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h
index 3263b2006d7..6b90275cfc6 100644
--- a/mindspore/ccsrc/mindrecord/include/shard_reader.h
+++ b/mindspore/ccsrc/mindrecord/include/shard_reader.h
@@ -294,6 +294,10 @@ class ShardReader {
   /// \brief get number of classes
   int64_t GetNumClasses(const std::string &file_path, const std::string &category_field);
 
+  /// \brief get the exact blob fields' data by indices
+  std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
+                                                       std::vector<uint32_t> &ordered_selected_columns_index);
+
 protected:
   uint64_t header_size_;  // header size
   uint64_t page_size_;    // page size
diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc
index 804613e40a1..1411e5a7d21 100644
--- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc
+++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc
@@ -790,6 +790,8 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
     n_consumer = kMinConsumerCount;
   }
   CheckNlp();
+
+  // dead code
  if (nlp_) {
     selected_columns_ = selected_columns;
   } else {
@@ -801,6 +803,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
       }
     }
   }
+  selected_columns_ = selected_columns;
 
   if (CheckColumnList(selected_columns_) == FAILED) {
     MS_LOG(ERROR) << "Illegal column list";
@@ -1060,6 +1063,36 @@ MSRStatus ShardReader::CreateTasks(const std::vector
+std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns(
+  std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) {
+  std::vector<uint8_t> exactly_blob_fields_bytes;
+
+  auto uint64_from_bytes = [&](int64_t pos) {
+    uint64_t result = 0;
+    for (uint64_t n = 0; n < kInt64Len; n++) {
+      result = (result << 8) + 
blob_fields_bytes[pos + n];
+    }
+    return result;
+  };
+
+  // get the exact blob fields
+  uint32_t current_index = 0;
+  uint64_t current_offset = 0;
+  uint64_t data_len = uint64_from_bytes(current_offset);
+  while (current_offset < blob_fields_bytes.size()) {
+    if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(),
+                    [&current_index](uint32_t &index) { return index == current_index; })) {
+      exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset,
+                                       blob_fields_bytes.begin() + current_offset + kInt64Len + data_len);
+    }
+    current_index++;
+    current_offset += kInt64Len + data_len;
+    data_len = uint64_from_bytes(current_offset);
+  }
+
+  return exactly_blob_fields_bytes;
+}
+
 TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
   // All tasks are done
   if (task_id >= static_cast<int>(tasks_.Size())) {
@@ -1077,6 +1110,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
     return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
   }
   const std::shared_ptr<Page> &page = ret.second;
+
   // Pack image list
   std::vector<uint8_t> images(addr[1] - addr[0]);
   auto file_offset = header_size_ + page_size_ * (page->get_page_id()) + addr[0];
@@ -1096,10 +1130,42 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
     return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
   }
 
+  // extract the exact blob bytes by selected columns
+  std::vector<uint8_t> images_with_exact_columns;
+  if (selected_columns_.size() == 0) {
+    images_with_exact_columns = images;
+  } else {
+    auto blob_fields = get_blob_fields();
+
+    std::vector<uint32_t> ordered_selected_columns_index;
+    uint32_t index = 0;
+    for (auto &blob_field : blob_fields.second) {
+      for (auto &field : selected_columns_) {
+        if (field.compare(blob_field) == 0) {
+          ordered_selected_columns_index.push_back(index);
+          break;
+        }
+      }
+      index++;
+    }
+
+    if (ordered_selected_columns_index.size() != 0) {
+      // extract the images
+      if (blob_fields.second.size() == 1) {
+        if (ordered_selected_columns_index.size() == 1) {
+          images_with_exact_columns = images;
+        }
+      } else {
+        images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index);
+      }
+    }
+  }
+
   // Deliver batch data to output map
   std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
   if (nlp_) {
-    json blob_fields = json::from_msgpack(images);
+    // dead code
+    json blob_fields = json::from_msgpack(images_with_exact_columns);
 
     json merge;
     if (selected_columns_.size() > 0) {
@@ -1117,7 +1183,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
     }
     batch.emplace_back(std::vector<uint8_t>{}, std::move(merge));
   } else {
-    batch.emplace_back(std::move(images), std::move(std::get<2>(task)));
+    batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task)));
   }
   return std::make_pair(SUCCESS, std::move(batch));
 }
diff --git a/mindspore/mindrecord/shardutils.py b/mindspore/mindrecord/shardutils.py
index 2f57505800a..a71dd228f64 100644
--- a/mindspore/mindrecord/shardutils.py
+++ b/mindspore/mindrecord/shardutils.py
@@ -92,15 +92,25 @@ def populate_data(raw, blob, columns, blob_fields, schema):
     if raw:
         # remove dummy fields
         raw = {k: v for k, v in raw.items() if k in schema}
+    else:
+        raw = {}
     if not blob_fields:
         return raw
+
+    # Get the order-preserving sequence of columns in blob
+    ordered_columns = []
+    if columns:
+        for blob_field in blob_fields:
+            if blob_field in columns:
+                ordered_columns.append(blob_field)
+    else:
+        ordered_columns = 
blob_fields + blob_bytes = bytes(blob) def _render_raw(field, blob_data): data_type = schema[field]['type'] data_shape = schema[field]['shape'] if 'shape' in schema[field] else [] - if columns and field not in columns: - return if data_shape: try: raw[field] = np.reshape(np.frombuffer(blob_data, dtype=data_type), data_shape) @@ -110,7 +120,9 @@ def populate_data(raw, blob, columns, blob_fields, schema): raw[field] = blob_data if len(blob_fields) == 1: - _render_raw(blob_fields[0], blob_bytes) + if len(ordered_columns) == 1: + _render_raw(blob_fields[0], blob_bytes) + return raw return raw def _int_from_bytes(xbytes: bytes) -> int: @@ -125,6 +137,6 @@ def populate_data(raw, blob, columns, blob_fields, schema): start += 8 return blob_bytes[start : start + n_bytes] - for i, blob_field in enumerate(blob_fields): + for i, blob_field in enumerate(ordered_columns): _render_raw(blob_field, _blob_at_position(i)) return raw diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py index 460a728b5c5..ba0c86dc8a0 100644 --- a/tests/ut/python/dataset/test_minddataset.py +++ b/tests/ut/python/dataset/test_minddataset.py @@ -545,3 +545,597 @@ def inputs(vectors, maxlen=50): mask = [1]*length + [0]*(maxlen-length) segment = [0]*maxlen return input_, mask, segment + +def test_write_with_multi_bytes_and_array_and_read_by_MindDataset(): + mindrecord_file_name = "test.mindrecord" + data = [{"file_name": "001.jpg", "label": 4, + "image1": bytes("image1 bytes abc", encoding='UTF-8'), + "image2": bytes("image1 bytes def", encoding='UTF-8'), + "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image3": bytes("image1 bytes ghi", encoding='UTF-8'), + "image4": bytes("image1 bytes jkl", encoding='UTF-8'), + "image5": bytes("image1 bytes mno", encoding='UTF-8'), + "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, + {"file_name": "002.jpg", "label": 5, + "image1": bytes("image2 bytes abc", encoding='UTF-8'), + "image2": bytes("image2 bytes def", encoding='UTF-8'), + "image3": bytes("image2 bytes ghi", encoding='UTF-8'), + "image4": bytes("image2 bytes jkl", encoding='UTF-8'), + "image5": bytes("image2 bytes mno", encoding='UTF-8'), + "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, + {"file_name": "003.jpg", "label": 6, + "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "image1": bytes("image3 bytes abc", encoding='UTF-8'), + "image2": bytes("image3 bytes def", encoding='UTF-8'), + "image3": bytes("image3 bytes ghi", encoding='UTF-8'), + "image4": bytes("image3 bytes 
jkl", encoding='UTF-8'), + "image5": bytes("image3 bytes mno", encoding='UTF-8'), + "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, + {"file_name": "004.jpg", "label": 7, + "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image4 bytes abc", encoding='UTF-8'), + "image2": bytes("image4 bytes def", encoding='UTF-8'), + "image3": bytes("image4 bytes ghi", encoding='UTF-8'), + "image4": bytes("image4 bytes jkl", encoding='UTF-8'), + "image5": bytes("image4 bytes mno", encoding='UTF-8'), + "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, + {"file_name": "005.jpg", "label": 8, + "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), + "image1": bytes("image5 bytes abc", encoding='UTF-8'), + "image2": bytes("image5 bytes def", encoding='UTF-8'), + "image3": bytes("image5 bytes ghi", encoding='UTF-8'), + "image4": bytes("image5 bytes jkl", encoding='UTF-8'), + "image5": bytes("image5 bytes mno", encoding='UTF-8'), + "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)}, + {"file_name": "006.jpg", "label": 9, + "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), + "image1": bytes("image6 bytes abc", encoding='UTF-8'), + "image2": bytes("image6 bytes def", encoding='UTF-8'), + "image3": bytes("image6 bytes ghi", encoding='UTF-8'), + "image4": bytes("image6 bytes jkl", encoding='UTF-8'), + "image5": bytes("image6 bytes mno", encoding='UTF-8'), + "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} + ] + + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "image1": {"type": "bytes"}, + "image2": {"type": "bytes"}, + "source_sos_ids": {"type": "int64", "shape": [-1]}, + "source_sos_mask": {"type": "int64", "shape": [-1]}, + "image3": {"type": "bytes"}, + "image4": {"type": "bytes"}, + "image5": {"type": "bytes"}, + "target_sos_ids": {"type": "int64", "shape": [-1]}, + "target_sos_mask": {"type": "int64", "shape": [-1]}, + "target_eos_ids": {"type": "int64", "shape": [-1]}, + "target_eos_mask": {"type": "int64", "shape": [-1]}, + "label": {"type": "int32"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # change data value to list + data_value_to_list = [] + for item in data: + new_data = {} + new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) + new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) + new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) + new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) + new_data['image3'] 
= np.asarray(list(item["image3"]), dtype=np.uint8) + new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) + new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) + new_data['source_sos_ids'] = item["source_sos_ids"] + new_data['source_sos_mask'] = item["source_sos_mask"] + new_data['target_sos_ids'] = item["target_sos_ids"] + new_data['target_sos_mask'] = item["target_sos_mask"] + new_data['target_eos_ids'] = item["target_eos_ids"] + new_data['target_eos_mask'] = item["target_eos_mask"] + data_value_to_list.append(new_data) + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 13 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["source_sos_ids", "source_sos_mask", "target_sos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data[num_iter][field]).all() + else: + assert item[field] == data[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 1 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 4 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 3 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_sos_ids", "image4", "source_sos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 3 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_sos_ids", "image5", "image4", "image3", "source_sos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 5 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 1 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_eos_mask", "image5", "image2", "source_sos_mask", "label"], + num_parallel_workers=num_readers, + 
shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 5 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["label", "target_eos_mask", "image1", "target_eos_ids", "source_sos_mask", + "image2", "image4", "image3", "source_sos_ids", "image5", "file_name"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 11 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + +def test_write_with_multi_bytes_and_MindDataset(): + mindrecord_file_name = "test.mindrecord" + data = [{"file_name": "001.jpg", "label": 43, + "image1": bytes("image1 bytes abc", encoding='UTF-8'), + "image2": bytes("image1 bytes def", encoding='UTF-8'), + "image3": bytes("image1 bytes ghi", encoding='UTF-8'), + "image4": bytes("image1 bytes jkl", encoding='UTF-8'), + "image5": bytes("image1 bytes mno", encoding='UTF-8')}, + {"file_name": "002.jpg", "label": 91, + "image1": bytes("image2 bytes abc", encoding='UTF-8'), + "image2": bytes("image2 bytes def", encoding='UTF-8'), + "image3": bytes("image2 bytes ghi", encoding='UTF-8'), + "image4": bytes("image2 bytes jkl", encoding='UTF-8'), + "image5": bytes("image2 bytes mno", encoding='UTF-8')}, + {"file_name": "003.jpg", "label": 61, + "image1": bytes("image3 bytes abc", encoding='UTF-8'), + "image2": bytes("image3 bytes def", encoding='UTF-8'), + "image3": bytes("image3 bytes ghi", encoding='UTF-8'), + "image4": bytes("image3 bytes jkl", encoding='UTF-8'), + "image5": bytes("image3 bytes mno", encoding='UTF-8')}, + {"file_name": "004.jpg", "label": 29, + "image1": bytes("image4 bytes abc", encoding='UTF-8'), + "image2": bytes("image4 bytes def", encoding='UTF-8'), + "image3": bytes("image4 bytes ghi", encoding='UTF-8'), + "image4": bytes("image4 bytes jkl", encoding='UTF-8'), + "image5": bytes("image4 bytes mno", encoding='UTF-8')}, + {"file_name": "005.jpg", "label": 78, + "image1": bytes("image5 bytes abc", encoding='UTF-8'), + "image2": bytes("image5 bytes def", encoding='UTF-8'), + "image3": bytes("image5 bytes ghi", encoding='UTF-8'), + "image4": bytes("image5 bytes jkl", encoding='UTF-8'), + "image5": bytes("image5 bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "label": 37, + "image1": bytes("image6 bytes abc", encoding='UTF-8'), + "image2": bytes("image6 bytes def", encoding='UTF-8'), + "image3": bytes("image6 bytes ghi", encoding='UTF-8'), + "image4": bytes("image6 bytes jkl", encoding='UTF-8'), + "image5": bytes("image6 bytes mno", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "image1": {"type": "bytes"}, + "image2": {"type": "bytes"}, + "image3": {"type": "bytes"}, + "label": {"type": "int32"}, + "image4": {"type": "bytes"}, + "image5": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + 
writer.write_raw_data(data) + writer.commit() + + # change data value to list + data_value_to_list = [] + for item in data: + new_data = {} + new_data['file_name'] = np.asarray(list(bytes(item["file_name"], encoding='utf-8')), dtype=np.uint8) + new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32) + new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8) + new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8) + new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8) + new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8) + new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8) + data_value_to_list.append(new_data) + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 7 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image1", "image2", "image5"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image2", "image4"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 2 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image5", "image2"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 2 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image5", "image2", "label"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["image4", "image5", "image2", "image3", "file_name"], + 
num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 5 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + +def test_write_with_multi_array_and_MindDataset(): + mindrecord_file_name = "test.mindrecord" + data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([448, 49, 50, 
51], dtype=np.int64)}, + {"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"source_sos_ids": {"type": "int64", "shape": [-1]}, + "source_sos_mask": {"type": "int64", "shape": [-1]}, + "source_eos_ids": {"type": "int64", "shape": [-1]}, + "source_eos_mask": {"type": "int64", "shape": [-1]}, + "target_sos_ids": {"type": "int64", "shape": [-1]}, + "target_sos_mask": {"type": "int64", "shape": [-1]}, + "target_eos_ids": {"type": "int64", "shape": [-1]}, + "target_eos_mask": {"type": "int64", "shape": [-1]}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + # change data value to list - do none + data_value_to_list = [] + for item in data: + new_data = {} + new_data['source_sos_ids'] = item["source_sos_ids"] + new_data['source_sos_mask'] = item["source_sos_mask"] + new_data['source_eos_ids'] = item["source_eos_ids"] + new_data['source_eos_mask'] = item["source_eos_mask"] + new_data['target_sos_ids'] = item["target_sos_ids"] + new_data['target_sos_mask'] = item["target_sos_mask"] + new_data['target_eos_ids'] = item["target_eos_ids"] + new_data['target_eos_mask'] = item["target_eos_mask"] + data_value_to_list.append(new_data) + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 8 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["source_eos_ids", "source_eos_mask", + "target_sos_ids", "target_sos_mask", + "target_eos_ids", "target_eos_mask"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 6 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["source_sos_ids", + "target_sos_ids", + "target_eos_mask"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = 
ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_eos_mask", + "source_eos_mask", + "source_sos_mask"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 3 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 2 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_eos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 1 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + num_readers = 1 + data_set = ds.MindDataset(dataset_file=mindrecord_file_name, + columns_list=["target_eos_mask", "target_eos_ids", + "target_sos_mask", "target_sos_ids", + "source_eos_mask", "source_eos_ids", + "source_sos_mask", "source_sos_ids"], + num_parallel_workers=num_readers, + shuffle=False) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + assert len(item) == 8 + for field in item: + if isinstance(item[field], np.ndarray): + assert (item[field] == data_value_to_list[num_iter][field]).all() + else: + assert item[field] == data_value_to_list[num_iter][field] + num_iter += 1 + assert num_iter == 6 + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) diff --git a/tests/ut/python/mindrecord/test_mindrecord_base.py b/tests/ut/python/mindrecord/test_mindrecord_base.py index 93e5c609f74..778ebccf84e 100644 --- a/tests/ut/python/mindrecord/test_mindrecord_base.py +++ b/tests/ut/python/mindrecord/test_mindrecord_base.py @@ -448,3 +448,456 @@ def test_cv_file_writer_no_raw(): reader.close() os.remove(NLP_FILE_NAME) os.remove("{}.db".format(NLP_FILE_NAME)) + +def test_write_read_process_with_multi_bytes(): + mindrecord_file_name = "test.mindrecord" + data = [{"file_name": "001.jpg", "label": 43, + "image1": bytes("image1 bytes abc", encoding='UTF-8'), + "image2": bytes("image1 bytes def", encoding='UTF-8'), + "image3": bytes("image1 bytes ghi", encoding='UTF-8'), + "image4": bytes("image1 bytes jkl", encoding='UTF-8'), + "image5": bytes("image1 bytes mno", encoding='UTF-8')}, + {"file_name": "002.jpg", "label": 91, + "image1": bytes("image2 bytes abc", encoding='UTF-8'), + "image2": bytes("image2 bytes def", encoding='UTF-8'), + "image3": bytes("image2 bytes ghi", encoding='UTF-8'), + "image4": bytes("image2 bytes jkl", encoding='UTF-8'), + "image5": bytes("image2 bytes mno", encoding='UTF-8')}, + {"file_name": "003.jpg", "label": 61, + "image1": bytes("image3 bytes abc", encoding='UTF-8'), + "image2": bytes("image3 bytes def", encoding='UTF-8'), + "image3": bytes("image3 bytes ghi", encoding='UTF-8'), + "image4": bytes("image3 bytes jkl", encoding='UTF-8'), + "image5": bytes("image3 bytes mno", encoding='UTF-8')}, + {"file_name": "004.jpg", "label": 29, + "image1": bytes("image4 bytes abc", encoding='UTF-8'), + "image2": bytes("image4 bytes def", encoding='UTF-8'), + "image3": bytes("image4 bytes ghi", encoding='UTF-8'), + 
"image4": bytes("image4 bytes jkl", encoding='UTF-8'), + "image5": bytes("image4 bytes mno", encoding='UTF-8')}, + {"file_name": "005.jpg", "label": 78, + "image1": bytes("image5 bytes abc", encoding='UTF-8'), + "image2": bytes("image5 bytes def", encoding='UTF-8'), + "image3": bytes("image5 bytes ghi", encoding='UTF-8'), + "image4": bytes("image5 bytes jkl", encoding='UTF-8'), + "image5": bytes("image5 bytes mno", encoding='UTF-8')}, + {"file_name": "006.jpg", "label": 37, + "image1": bytes("image6 bytes abc", encoding='UTF-8'), + "image2": bytes("image6 bytes def", encoding='UTF-8'), + "image3": bytes("image6 bytes ghi", encoding='UTF-8'), + "image4": bytes("image6 bytes jkl", encoding='UTF-8'), + "image5": bytes("image6 bytes mno", encoding='UTF-8')} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "image1": {"type": "bytes"}, + "image2": {"type": "bytes"}, + "image3": {"type": "bytes"}, + "label": {"type": "int32"}, + "image4": {"type": "bytes"}, + "image5": {"type": "bytes"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + reader = FileReader(mindrecord_file_name) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 7 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + reader2 = FileReader(file_name=mindrecord_file_name, columns=["image1", "image2", "image5"]) + count = 0 + for index, x in enumerate(reader2.get_next()): + assert len(x) == 3 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader2.close() + + reader3 = FileReader(file_name=mindrecord_file_name, columns=["image2", "image4"]) + count = 0 + for index, x in enumerate(reader3.get_next()): + assert len(x) == 2 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader3.close() + + reader4 = FileReader(file_name=mindrecord_file_name, columns=["image5", "image2"]) + count = 0 + for index, x in enumerate(reader4.get_next()): + assert len(x) == 2 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader4.close() + + reader5 = FileReader(file_name=mindrecord_file_name, columns=["image5", "image2", "label"]) + count = 0 + for index, x in enumerate(reader5.get_next()): + assert len(x) == 3 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader5.close() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + +def test_write_read_process_with_multi_array(): + mindrecord_file_name = "test.mindrecord" + data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([6, 7, 8, 9, 10, 
11, 12], dtype=np.int64), + "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)}, + {"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64), + "source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64), + "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} + ] + writer = FileWriter(mindrecord_file_name) + schema = {"source_sos_ids": {"type": "int64", "shape": [-1]}, + 
"source_sos_mask": {"type": "int64", "shape": [-1]}, + "source_eos_ids": {"type": "int64", "shape": [-1]}, + "source_eos_mask": {"type": "int64", "shape": [-1]}, + "target_sos_ids": {"type": "int64", "shape": [-1]}, + "target_sos_mask": {"type": "int64", "shape": [-1]}, + "target_eos_ids": {"type": "int64", "shape": [-1]}, + "target_eos_mask": {"type": "int64", "shape": [-1]}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + reader = FileReader(mindrecord_file_name) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 8 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + reader = FileReader(file_name=mindrecord_file_name, columns=["source_eos_ids", "source_eos_mask", + "target_sos_ids", "target_sos_mask", + "target_eos_ids", "target_eos_mask"]) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 6 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + reader = FileReader(file_name=mindrecord_file_name, columns=["source_sos_ids", + "target_sos_ids", + "target_eos_mask"]) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 3 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_mask", + "source_eos_mask", + "source_sos_mask"]) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 3 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_ids"]) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 1 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name)) + +def test_write_read_process_with_multi_bytes_and_array(): + mindrecord_file_name = "test.mindrecord" + data = [{"file_name": "001.jpg", "label": 4, + "image1": bytes("image1 bytes abc", encoding='UTF-8'), + "image2": bytes("image1 bytes def", encoding='UTF-8'), + "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image3": bytes("image1 bytes ghi", encoding='UTF-8'), + "image4": bytes("image1 bytes jkl", encoding='UTF-8'), + "image5": bytes("image1 bytes mno", encoding='UTF-8'), + "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": 
np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)}, + {"file_name": "002.jpg", "label": 5, + "image1": bytes("image2 bytes abc", encoding='UTF-8'), + "image2": bytes("image2 bytes def", encoding='UTF-8'), + "image3": bytes("image2 bytes ghi", encoding='UTF-8'), + "image4": bytes("image2 bytes jkl", encoding='UTF-8'), + "image5": bytes("image2 bytes mno", encoding='UTF-8'), + "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)}, + {"file_name": "003.jpg", "label": 6, + "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "image1": bytes("image3 bytes abc", encoding='UTF-8'), + "image2": bytes("image3 bytes def", encoding='UTF-8'), + "image3": bytes("image3 bytes ghi", encoding='UTF-8'), + "image4": bytes("image3 bytes jkl", encoding='UTF-8'), + "image5": bytes("image3 bytes mno", encoding='UTF-8'), + "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)}, + {"file_name": "004.jpg", "label": 7, + "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "image1": bytes("image4 bytes abc", encoding='UTF-8'), + "image2": bytes("image4 bytes def", encoding='UTF-8'), + "image3": bytes("image4 bytes ghi", encoding='UTF-8'), + "image4": bytes("image4 bytes jkl", encoding='UTF-8'), + "image5": bytes("image4 bytes mno", encoding='UTF-8'), + "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)}, + {"file_name": "005.jpg", "label": 8, + "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64), + "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64), + "image1": bytes("image5 bytes abc", encoding='UTF-8'), + "image2": bytes("image5 bytes def", encoding='UTF-8'), + "image3": bytes("image5 bytes ghi", encoding='UTF-8'), + "image4": bytes("image5 bytes jkl", encoding='UTF-8'), + "image5": bytes("image5 bytes mno", encoding='UTF-8'), + "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)}, + {"file_name": "006.jpg", "label": 9, + "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64), + "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64), + "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64), + "image1": bytes("image6 bytes abc", encoding='UTF-8'), + "image2": bytes("image6 bytes def", encoding='UTF-8'), + "image3": bytes("image6 bytes 
ghi", encoding='UTF-8'), + "image4": bytes("image6 bytes jkl", encoding='UTF-8'), + "image5": bytes("image6 bytes mno", encoding='UTF-8'), + "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64), + "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64), + "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)} + ] + + writer = FileWriter(mindrecord_file_name) + schema = {"file_name": {"type": "string"}, + "image1": {"type": "bytes"}, + "image2": {"type": "bytes"}, + "source_sos_ids": {"type": "int64", "shape": [-1]}, + "source_sos_mask": {"type": "int64", "shape": [-1]}, + "image3": {"type": "bytes"}, + "image4": {"type": "bytes"}, + "image5": {"type": "bytes"}, + "target_sos_ids": {"type": "int64", "shape": [-1]}, + "target_sos_mask": {"type": "int64", "shape": [-1]}, + "target_eos_ids": {"type": "int64", "shape": [-1]}, + "target_eos_mask": {"type": "int64", "shape": [-1]}, + "label": {"type": "int32"}} + writer.add_schema(schema, "data is so cool") + writer.write_raw_data(data) + writer.commit() + + reader = FileReader(mindrecord_file_name) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 13 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + reader = FileReader(file_name=mindrecord_file_name, columns=["source_sos_ids", "source_sos_mask", + "target_sos_ids"]) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 3 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + reader = FileReader(file_name=mindrecord_file_name, columns=["image2", "source_sos_mask", + "image3", "target_sos_ids"]) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 4 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + reader = FileReader(file_name=mindrecord_file_name, columns=["target_sos_ids", "image4", + "source_sos_ids"]) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 3 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + reader = FileReader(file_name=mindrecord_file_name, columns=["target_sos_ids", "image5", + "image4", "image3", "source_sos_ids"]) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 5 + for field in x: + if isinstance(x[field], np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + reader = FileReader(file_name=mindrecord_file_name, columns=["target_eos_mask", "image5", "image2", + "source_sos_mask", "label"]) + count = 0 + for index, x in enumerate(reader.get_next()): + assert len(x) == 5 + for field in x: + if isinstance(x[field], 
np.ndarray): + assert (x[field] == data[count][field]).all() + else: + assert x[field] == data[count][field] + count = count + 1 + logger.info("#item{}: {}".format(index, x)) + assert count == 6 + reader.close() + + os.remove("{}".format(mindrecord_file_name)) + os.remove("{}.db".format(mindrecord_file_name))
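
The heart of this change is how a selected subset of blob fields is carved out of the raw blob buffer: with more than one blob field, each field is stored as an 8-byte big-endian length prefix followed by its payload, and both ExtractBlobFieldBySelectColumns (C++) and the _blob_at_position walk in populate_data (Python) scan that layout, keeping fields in blob order rather than in the order the caller's columns_list names them. Below is a minimal Python sketch of that walk, assuming the layout just described; the helper names pack_blob and extract_selected are illustrative and are not part of the patch.

import struct

def pack_blob(field_values):
    # Multi-field layout: each field is an 8-byte big-endian length
    # followed by that many payload bytes. (With a single blob field the
    # payload is stored alone, which is why the patch special-cases
    # blob_fields.second.size() == 1 / len(blob_fields) == 1.)
    buf = bytearray()
    for value in field_values:
        buf += struct.pack('>Q', len(value)) + value
    return bytes(buf)

def extract_selected(blob, selected_indices):
    # Walk the buffer field by field and keep only the fields whose
    # position is selected, preserving blob order -- the same walk as
    # ExtractBlobFieldBySelectColumns in the C++ reader.
    out = bytearray()
    offset, index = 0, 0
    while offset < len(blob):
        (data_len,) = struct.unpack_from('>Q', blob, offset)
        end = offset + 8 + data_len
        if index in selected_indices:
            out += blob[offset:end]  # keep the length prefix and payload
        offset, index = end, index + 1
    return bytes(out)

blob = pack_blob([b'image bytes', b'mask bytes', b'label bytes'])
# Selecting fields 0 and 2 keeps them in blob order, regardless of how the
# caller ordered its columns_list.
assert extract_selected(blob, {0, 2}) == pack_blob([b'image bytes', b'label bytes'])

This is also why the patch fixes the reported error: the byte stream and the index mapping are now kept consistent, since the reader returns only the selected blob fields and SetColumnsBlob / populate_data compute positions over that same order-preserved subset instead of over all blob fields.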