fix boost bert timeout

This commit is contained in:
luoyang 2022-04-07 14:08:25 +08:00
parent 04c067500d
commit 064cd9b999
2 changed files with 7 additions and 5 deletions

View File

@ -3632,6 +3632,7 @@ class TransferDataset(Dataset):
self._send_epoch_end = replace_none(send_epoch_end, True)
self._create_data_info_queue = create_data_info_queue
self._to_device = None
self.column_name = self.get_col_names()
def parse(self, children=None):
total_batch = 0
@ -3690,7 +3691,7 @@ class TransferDataset(Dataset):
def get_offload_model(self):
if self._to_device is not None:
return self._to_device.get_offload_model(self.get_col_names())
return self._to_device.get_offload_model(self.column_name)
raise RuntimeError("get_offload_model, _to_device is None")

View File

@ -32,10 +32,11 @@ def check_add_offload_sink_mode(dataset, dataset_helper, network):
if hasattr(dataset, '__no_send__'):
# Dataset was not sent to device. Skip adding offload.
return network
# We don't use dataset.__transfer_dataset__ because there will be a device_queue rdr warning log
iterator = dataset.create_tuple_iterator(num_epochs=1)
if iterator.offload_model is not None:
network = ApplyPreTransform(iterator.offload_model, network)
offload_model = dataset.__transfer_dataset__.get_offload_model()
# See if the offload pass identified any operations to be offloaded
if offload_model.transform_list != []:
check_concat_zip_dataset(dataset.__transfer_dataset__)
network = ApplyPreTransform(offload_model, network)
return network