pr to master #8
|
@ -40,13 +40,21 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
|
|||
else:
|
||||
columns_list = ["input_ids", "input_mask", "segment_ids"]
|
||||
|
||||
shard_equal_rows = True
|
||||
shuffle = (do_shuffle == "true")
|
||||
if device_num == 1:
|
||||
shard_equal_rows = False
|
||||
shuffle = False
|
||||
|
||||
if data_type == DataType.MINDRECORD:
|
||||
ds = de.MindDataset(data_files, columns_list=columns_list,
|
||||
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank)
|
||||
else:
|
||||
ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
|
||||
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
|
||||
shard_equal_rows=True)
|
||||
shuffle=shuffle, num_shards=device_num, shard_id=rank,
|
||||
shard_equal_rows=shard_equal_rows)
|
||||
if device_num == 1 and shuffle is True:
|
||||
ds = ds.shuffle(10000)
|
||||
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = ds.map(operations=type_cast_op, input_columns="segment_ids")
|
||||
|
|
Loading…
Reference in New Issue