diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc
index 99c94acd2e4..8060006161e 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings.cc
@@ -76,6 +76,12 @@ PYBIND_REGISTER(
                  THROW_IF_ERROR(de.GetOutputTypes(&out));
                  return out;
                })
+          .def("GetDataInfo",
+               [](DEPipeline &de) {
+                 py::list types, shapes;
+                 THROW_IF_ERROR(de.GetDataInfo(&types, &shapes));
+                 return py::make_tuple(types, shapes);
+               })
           .def("GetDatasetSize", &DEPipeline::GetDatasetSize)
           .def("GetBatchSize", &DEPipeline::GetBatchSize)
           .def("GetNumClasses", &DEPipeline::GetNumClasses)
diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc
index f60912a01d4..a227197fa17 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc
@@ -241,6 +241,30 @@ Status DEPipeline::GetNextAsList(py::list *output) {
   return Status::OK();
 }
 
+Status DEPipeline::GetDataInfo(py::list *types, py::list *shapes) {
+  Status s;
+  DATA_INFO data_info;
+  // tree_.root() must be DeviceQueueOp
+  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
+  if (op == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "GetDataInfo only supported by DeviceQueueOp");
+  }
+  {
+    py::gil_scoped_release gil_release;
+    s = op->GetDataInfo(&data_info);
+  }
+  RETURN_IF_NOT_OK(s);
+  for (auto el : data_info) {
+    types->append(el.first.AsNumpyType());
+    py::list shape;
+    for (auto dim : el.second.AsVector()) {
+      shape.append(dim);
+    }
+    shapes->append(shape);
+  }
+  return Status::OK();
+}
+
 Status DEPipeline::GetOutputShapes(py::list *output) {
   std::vector<TensorShape> shapes;
   Status s;
@@ -1070,6 +1094,8 @@ Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
       (void)builder->SetSendEpochEnd(ToBool(value));
     } else if (key == "total_batch") {
       (void)builder->SetTotalBatch(ToInt(value));
+    } else if (key == "create_data_info_queue") {
+      (void)builder->SetCreateDataInfoQueue(ToBool(value));
     }
   }
 }
diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h
index afe93b1a990..c96e4e79024 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h
+++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h
@@ -111,6 +111,8 @@ class DEPipeline {
 
   Status GetOutputTypes(py::list *output);
 
+  Status GetDataInfo(py::list *types, py::list *shapes);
+
   Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type);
 
   int GetDatasetSize() const;
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc
index dd3b5280022..b8764048b1e 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc
@@ -33,7 +33,7 @@ namespace mindspore {
 namespace dataset {
 DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
-                             bool send_epoch_end, int32_t total_batch)
+                             bool send_epoch_end, int32_t total_batch, bool create_data_info_queue)
     : PipelineOp(1),
       channel_name_(channel_name),
       device_type_(device_type),
@@ -41,7 +41,8 @@ DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, i
       prefetch_size_(prefetch_size),
       send_epoch_end_(send_epoch_end),
       stop_send_(false),
-      total_batch_(total_batch) {
+      total_batch_(total_batch),
+      create_data_info_queue_(create_data_info_queue) {
 #ifdef ENABLE_TDTQUE
   ascend_keep_waiting_ = true;
 #endif
@@ -87,6 +88,10 @@ Status DeviceQueueOp::operator()() {
   if (device_type_ == DeviceType::Ascend) {
 #ifdef ENABLE_TDTQUE
+    if (create_data_info_queue_) {
+      data_info_queue_ptr_ = std::make_unique<DATA_INFO_QUEUE>(kDataInfoQueueCapacity);
+      RETURN_IF_NOT_OK(data_info_queue_ptr_->Register(tree_->AllTasks()));
+    }
     RETURN_IF_NOT_OK(SendDataToAscend());
 #endif
   } else if (device_type_ == DeviceType::GPU) {
@@ -142,6 +147,13 @@ Status DeviceQueueOp::SendDataToAscend() {
         return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
       }
     }
+    if (create_data_info_queue_) {
+      DATA_INFO data_info;
+      (void)std::transform(
+        currRow.begin(), currRow.end(), std::back_inserter(data_info),
+        [](const std::shared_ptr<Tensor> &ts) { return std::make_pair(ts->type(), ts->shape()); });
+      RETURN_IF_NOT_OK(data_info_queue_ptr_->Add(data_info));
+    }
     if (isProfilingEnable) {
       end_time = ProfilingTime::GetCurMilliSecond();
@@ -157,6 +169,7 @@
       profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, send_batch + 1, connector_size);
     }
     send_batch++;
+
     if (total_batch_ > 0 && send_batch >= total_batch_) {
       is_break_loop = true;
       break;
@@ -196,6 +209,21 @@
   return Status::OK();
 }
+
+#endif
+
+#ifdef ENABLE_TDTQUE
+Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
+  if (!create_data_info_queue_) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "DataInfo queue is not created.");
+  }
+  RETURN_IF_NOT_OK(data_info_queue_ptr_->PopFront(data_info));
+  return Status::OK();
+}
+#else
+Status DeviceQueueOp::GetDataInfo(DATA_INFO *data_info) {
+  return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "GetDataInfo is not supported yet.");
+}
 #endif
 
 #ifdef ENABLE_GPUQUE
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h
index 1a97b1a6030..fb23972f104 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h
@@ -18,6 +18,7 @@
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "minddata/dataset/engine/datasetops/pipeline_op.h"
@@ -25,6 +26,7 @@
 #include "minddata/dataset/util/status.h"
 
 #ifdef ENABLE_TDTQUE
+#include "minddata/dataset/util/queue.h"
 #include "minddata/dataset/engine/tdt/tdt_plugin.h"
 #endif
@@ -37,6 +39,10 @@ using mindspore::device::GpuBufferMgr;
 
 namespace mindspore {
 namespace dataset {
+
+using DATA_INFO = std::vector<std::pair<DataType, TensorShape>>;
+using DATA_INFO_QUEUE = Queue<DATA_INFO>;
+const int kDataInfoQueueCapacity = 128;
 class DeviceQueueOp : public PipelineOp {
  public:
   static const uint32_t INVALID_HANDLE = 0xffffffffUL;
@@ -91,13 +97,18 @@ class DeviceQueueOp : public PipelineOp {
       return *this;
     }
 
+    Builder &SetCreateDataInfoQueue(bool create_data_info_queue) {
+      builder_create_data_info_queue_ = create_data_info_queue;
+      return *this;
+    }
     // Name: Build()
     // Description: The final step for building a DeviceQueueOp via the Builder is
     //              to call this Build() method.  It will instantiate the DeviceQueueOp
     //              and return it to caller as a shared pointer.
     Status Build(std::shared_ptr<DeviceQueueOp> *ptr) {
       *ptr = std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_,
-                                             builder_prefetch_size_, builder_send_epoch_end_, builder_total_batch_);
+                                             builder_prefetch_size_, builder_send_epoch_end_, builder_total_batch_,
+                                             builder_create_data_info_queue_);
       return Status::OK();
     }
 
@@ -108,12 +119,13 @@ class DeviceQueueOp : public PipelineOp {
     std::string builder_channel_name_;
     bool builder_send_epoch_end_;
     int32_t builder_total_batch_;
+    bool builder_create_data_info_queue_;
   };
 
   // Name: constructor
   // Description
   DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
-                bool send_epoch_end, int32_t total_batch);
+                bool send_epoch_end, int32_t total_batch, bool create_data_info_queue);
 
   // Name: destructor
   // Description
@@ -138,6 +150,8 @@ class DeviceQueueOp : public PipelineOp {
   void StopWaiting() { ascend_keep_waiting_ = false; }
 #endif
 
+  Status GetDataInfo(DATA_INFO *data_info);
+
   // Name: Print()
   // Description: A function that prints info about the node
   void Print(std::ostream &out,              // In: The output stream to print to
@@ -170,6 +184,7 @@ class DeviceQueueOp : public PipelineOp {
 #ifdef ENABLE_TDTQUE
   Status SendDataToAscend();
   bool ascend_keep_waiting_;
+
 #endif
 
 #ifdef ENABLE_GPUQUE
@@ -190,6 +205,8 @@ class DeviceQueueOp : public PipelineOp {
   const bool send_epoch_end_;
   bool stop_send_;
   int32_t total_batch_;
+  bool create_data_info_queue_;
+  std::unique_ptr<DATA_INFO_QUEUE> data_info_queue_ptr_;
 
 #ifdef ENABLE_TDTQUE
   std::shared_ptr<TdtPlugin> tdtInstancePtr;
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc
index dcace2b1764..847646c02f2 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc
@@ -62,9 +62,8 @@ std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
   } else if (device_type_ == "Ascend") {
     type = DeviceQueueOp::DeviceType::Ascend;
   }
-
-  node_ops.push_back(
-    std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, total_batch_));
+  node_ops.push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_,
+                                                     total_batch_, false));
   return node_ops;
 }
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index bac4990f03f..be157e80188 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -1005,7 +1005,7 @@ class Dataset:
         return dataset
 
     @check_device_send
-    def device_que(self, prefetch_size=None, send_epoch_end=True):
+    def device_que(self, prefetch_size=None, send_epoch_end=True, create_data_info_queue=False):
         """
         Return a transferred Dataset that transfers data through a device.
 
@@ -1013,6 +1013,8 @@ class Dataset:
             prefetch_size (int, optional): Prefetch number of records ahead of the
                 user's request (default=None).
             send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True).
+            create_data_info_queue (bool, optional): Whether to create a queue which stores
+                types and shapes of data or not (default=False).
 
         Note:
             If device is Ascend, features of data will be transferred one by one. The limitation
@@ -1021,15 +1023,17 @@
         Return:
             TransferDataset, dataset for transferring.
""" - return self.to_device(send_epoch_end=send_epoch_end) + return self.to_device(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue) @check_device_send - def to_device(self, send_epoch_end=True): + def to_device(self, send_epoch_end=True, create_data_info_queue=False): """ Transfer data through CPU, GPU or Ascend devices. Args: send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True). + create_data_info_queue (bool, optional): Whether to create queue which stores + types and shapes of data or not(default=False). Note: If device is Ascend, features of data will be transferred one by one. The limitation @@ -1078,7 +1082,7 @@ class Dataset: distribution_path, device_id = get_distribution(self) if distribution_path == "": - return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end) + return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end, create_data_info_queue) try: with open(distribution_path, 'r') as distribution_f: dist = json.load(distribution_f) @@ -1088,7 +1092,7 @@ class Dataset: except Exception: raise RuntimeError("Failed to read Distribution file.") - return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end) + return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end, create_data_info_queue) @check_save def save(self, file_name, num_files=1, file_type='mindrecord'): @@ -2640,9 +2644,12 @@ class TransferDataset(DatasetOp): device_id (int): ID of device. device_type (str): Type of device, including "CPU", "GPU", and "Ascend". send_epoch_end (bool, optional): Whether to send end of sequence to device or not (default=True). + create_data_info_queue (bool, optional): Whether to create queue which stores + types and shapes of data or not(default=False). 
""" - def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True): + def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True, + create_data_info_queue=False): super().__init__() self.children.append(input_dataset) input_dataset.parent.append(self) @@ -2652,6 +2659,7 @@ class TransferDataset(DatasetOp): self._device_id = device_id self._send_epoch_end = send_epoch_end self.iterator = None + self._create_data_info_queue = create_data_info_queue def get_args(self): args = super().get_args() @@ -2661,6 +2669,7 @@ class TransferDataset(DatasetOp): args["send_epoch_end"] = self._send_epoch_end if hasattr(self.children[0], "__total_batch__"): args["total_batch"] = self.children[0].__total_batch__ + args["create_data_info_queue"] = self._create_data_info_queue return args def create_dict_iterator(self, num_epochs=-1, output_numpy=False): @@ -2692,6 +2701,9 @@ class TransferDataset(DatasetOp): def continue_send(self): self.iterator.depipeline.ContinueSend() + def get_data_info(self): + return self.iterator.depipeline.GetDataInfo() + class RangeDataset(MappableDataset): """ diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index 6e0b32e2c54..ced97b027ca 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -50,7 +50,7 @@ def _get_types_and_shapes(dataset): return dataset_types, dataset_shapes -def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): +def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_info_queue=False): """Initialize and execute the dataset graph.""" batch_size = exec_dataset.get_batch_size() input_indexs = exec_dataset.input_indexs @@ -58,7 +58,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): # transform data format dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) send_epoch_end = bool(dataset_size == -1) - exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end) + exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue) _executor.init_dataset(exec_dataset.queue_name, dataset_size, diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 47483020d2d..bc79dec9104 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -17,6 +17,7 @@ import math import os from mindspore._checkparam import Validator +from mindspore.common.dtype import pytype_to_dtype from .. import context, nn from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list from ..nn.wrap import GetNextSingleOp @@ -31,6 +32,7 @@ def _send_data(dataset, epoch_num): exec_dataset.send(epoch_num) dataset.__has_sent__ = True + def _send_data_no_flag(dataset, epoch_num): """Engine dataset to write data to tdt queue directly.""" exec_dataset = dataset.__transfer_dataset__ @@ -70,6 +72,7 @@ def connect_network_with_dataset(network, dataset_helper): Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the dataset channel 'queue_name' and performs the forward computation. 
""" + def __init__(self, network, dataset_types, dataset_shapes, queue_name): super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags()) # Also copy the flag in `network` construct @@ -88,9 +91,30 @@ def connect_network_with_dataset(network, dataset_helper): if isinstance(dataset_iter, _DatasetIterNormal): raise RuntimeError("Dataset should be connected with network only in sink mode.") - if not hasattr(dataset, '__me_inited__') and (context.get_context("device_target") == "Ascend" - or context.get_context("device_target") == "GPU") and not \ - context.get_context("enable_ge"): + if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ + and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ + and context.get_context("device_target") == "Ascend": + + if not hasattr(dataset, '__network__'): + dataset.__network__ = network + network = dataset.__network__ + + dataset_types, dataset_shapes = dataset_helper.get_data_info() + dataset_types = [pytype_to_dtype(x) for x in dataset_types] + + key = str(dataset_types) + str(dataset_shapes) + if hasattr(dataset, '__network_manage__') and key in dataset.__network_manage__: + network = dataset.__network_manage__[key] + else: + network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name) + dataset.__network_manage__ = dataset.__network_manage__ if hasattr( + dataset, '__network_manage__') else dict() + dataset.__network_manage__[key] = network + + return network + + if not hasattr(dataset, '__me_inited__') and (context.get_context("device_target") == "Ascend" or \ + context.get_context("device_target") == "GPU") and not context.get_context("enable_ge"): dataset.__me_inited__ = True dataset_types, dataset_shapes = dataset_helper.types_shapes() @@ -99,7 +123,6 @@ def connect_network_with_dataset(network, dataset_helper): network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) return network - class DatasetHelper: """ DatasetHelper is a class to process the MindData dataset and it provides the information of dataset. 
@@ -171,18 +194,25 @@ class DatasetHelper:
         """continue send data to device at the beginning of epoch."""
         self.iter.continue_send()
 
+    def get_data_info(self):
+        return self.iter.get_data_info()
+
 
 class _DatasetIter:
     """Base iter for dataset helper"""
+
     def __init__(self, dataset, sink_size, epoch_num):
         self.dataset = dataset
         self.sink_size = sink_size
-        self.sink_count = 1
+        self.sink_count = self.get_sink_count(dataset)
 
         if not hasattr(dataset, '__transfer_dataset__'):
             if hasattr(dataset, '__loop_size__'):
                 self.sink_size = dataset.__loop_size__
-            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size)
+            create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and context.get_context(
+                "device_target") == "Ascend")
+            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
+                                                           create_data_info_queue=create_data_info_queue)
 
             if not hasattr(dataset, '__no_send__'):
                 _send_data(dataset, epoch_num)
@@ -191,6 +221,7 @@ class _DatasetIter:
 
         self.stop_send = dataset.__transfer_dataset__.stop_send
         self.continue_send = dataset.__transfer_dataset__.continue_send
+        self.get_data_info = dataset.__transfer_dataset__.get_data_info
         self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
 
     def __iter__(self):
@@ -223,7 +254,7 @@ class _DatasetIter:
             sink_size = self.dataset.__loop_size__
         else:
             if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend" \
-               or context.get_context("device_target") == "GPU":
+                    or context.get_context("device_target") == "GPU":
                 if self.sink_size > 0:
                     sink_size = self.sink_size
                 else:
@@ -233,6 +264,7 @@ class _DatasetIter:
 
 class _DatasetIterGE(_DatasetIter):
     """Iter for GE."""
+
     def __init__(self, dataset, sink_size, epoch_num):
         super().__init__(dataset, sink_size, epoch_num)
         self.sink_count = self.get_sink_count(dataset)
@@ -249,6 +281,7 @@ class _DatasetIterGE(_DatasetIter):
 
 class _DatasetIterMSLoopSink(_DatasetIter):
     """Iter for context (device_target=Ascend)"""
+
    def __init__(self, dataset, sink_size, epoch_num):
         super().__init__(dataset, sink_size, epoch_num)
         self.sink_count = self.get_sink_count(dataset)
@@ -270,6 +303,7 @@ class _DatasetIterMSLoopSink(_DatasetIter):
 
 class _DatasetIterMS(_DatasetIter):
     """Iter for MS(enable_loop_sink=False)."""
+
     def __init__(self, dataset, sink_size, epoch_num):
         super().__init__(dataset, sink_size, epoch_num)
         if sink_size > 0:
@@ -283,11 +317,13 @@ class _DatasetIterMS(_DatasetIter):
 
 class _DatasetIterPSLite(_DatasetIter):
     """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
+
     def __init__(self, dataset, sink_size, epoch_num):
         super().__init__(dataset, sink_size, epoch_num)
         self.sink_count = 1
         self.sink_size = 1
         self.op = None
+
         def op():
             return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1)
         self.op = op
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 63f3b73ddc1..113dbe8b038 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -250,11 +250,14 @@ class Model:
             scaling_sens /= self._device_number
         return scaling_sens
 
-    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1):
+    def _exec_preprocess(self, network, is_train, phase, dataset,
+                         dataset_sink_mode, sink_size=-1, epoch_num=1, dataset_helper=None):
         """Initializes dataset."""
         if dataset_sink_mode and not is_train:
             dataset.__loop_size__ = 1
-        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
+
+        if dataset_helper is None:
+            dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
 
         if dataset_sink_mode:
             network = connect_network_with_dataset(network, dataset_helper)
@@ -405,15 +408,6 @@
             epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
             train_dataset.__total_batch__ = epoch * sink_size
 
-        dataset_helper, train_network = self._exec_preprocess(self._train_network,
-                                                              is_train=True,
-                                                              phase='train',
-                                                              dataset=train_dataset,
-                                                              dataset_sink_mode=True,
-                                                              sink_size=sink_size,
-                                                              epoch_num=epoch_num)
-        self._train_network = train_network
-        cb_params.train_network = self._train_network
         cb_params.cur_step_num = 0
 
         run_context = RunContext(cb_params)
@@ -421,9 +415,21 @@
         # used to stop training for early stop, such as stopAtTIme or stopATStep
         should_stop = False
+        dataset_helper = None
         for i in range(epoch):
             cb_params.cur_epoch_num = i + 1
             list_callback.epoch_begin(run_context)
+            dataset_helper, train_network = self._exec_preprocess(self._train_network,
+                                                                  is_train=True,
+                                                                  phase='train',
+                                                                  dataset=train_dataset,
+                                                                  dataset_sink_mode=True,
+                                                                  sink_size=sink_size,
+                                                                  epoch_num=epoch_num,
+                                                                  dataset_helper=dataset_helper)
+
+            self._train_network = train_network
+            cb_params.train_network = self._train_network
 
             # for data sink dataset_helper only iter once, other wise iter epoch_size times.
             for inputs in dataset_helper:
diff --git a/model_zoo/official/nlp/gpt/src/pre_process.py b/model_zoo/official/nlp/gpt/src/pre_process.py
index 5aeb7821e18..dbccbf9f5b2 100644
--- a/model_zoo/official/nlp/gpt/src/pre_process.py
+++ b/model_zoo/official/nlp/gpt/src/pre_process.py
@@ -133,7 +133,7 @@ def tokenize_lambada(file_path):
     with open(file_path, 'r', encoding='utf-8') as f:
         for line in f.readlines():
             para = json.loads(line)['text'].replace(
-                "“", '""').replace("”", '"').strip().strip(".")
+                "“", '"').replace("”", '"').strip().strip(".")
             tokenized_text = tokenizer.tokenize(para)
             content += tokenizer.convert_tokens_to_ids(tokenized_text) + [EOT]
     for chunk in chunks(content, SEQ_LEN):
diff --git a/tests/dataset_mock.py b/tests/dataset_mock.py
index 988f15dd8c7..dfb6902581c 100644
--- a/tests/dataset_mock.py
+++ b/tests/dataset_mock.py
@@ -50,7 +50,7 @@ class MindData:
     def input_indexs(self):
         return self._input_indexs
 
-    def device_que(self, send_epoch_end=True):
+    def device_que(self, send_epoch_end=True, create_data_info_queue=False):
        self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
         self.send_epoch_end = send_epoch_end
         return self
@@ -61,6 +61,9 @@ class MindData:
     def send(self, num_epochs=-1):
         pass
 
+    def get_data_info(self):
+        pass
+
     def stop_send(self):
         pass
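For reference, a minimal usage sketch of the path this patch adds. It is not part of the patch: it assumes an Ascend build with TDT enabled (ENABLE_TDTQUE) and device_target="Ascend", and the dataset construction and column name "x" are illustrative placeholders.

    import numpy as np
    import mindspore.dataset as ds

    # Any pipeline works; NumpySlicesDataset is used here only as a stand-in.
    data = ds.NumpySlicesDataset({"x": np.ones((4, 2), dtype=np.float32)}, shuffle=False)

    # create_data_info_queue=True makes DeviceQueueOp record the (type, shape)
    # pair of every row it pushes, in a bounded queue of kDataInfoQueueCapacity.
    transfer = data.device_que(send_epoch_end=False, create_data_info_queue=True)
    transfer.send()

    # Pops the per-column info of one sent row; types arrive as NumPy types
    # because DEPipeline::GetDataInfo converts them with AsNumpyType().
    types, shapes = transfer.get_data_info()

This is the same sequence connect_network_with_dataset now runs internally when sink_size == 1 and sink_count == 1 on Ascend: it pops one batch's types and shapes per step, converts the types with pytype_to_dtype, and caches one _DataWrapper network per distinct str(dataset_types) + str(dataset_shapes) key, so batches whose shapes vary each get a matching GetNext graph.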