!23150 improve NCF training speed

Merge pull request !23150 from zhouneng/code_docs_fix_issue_I48UPH
This commit is contained in:
i-robot 2021-09-09 08:37:53 +00:00 committed by Gitee
commit 7a3a1dbb28
1 changed file with 7 additions and 14 deletions

View File

@@ -579,20 +579,13 @@ def create_dataset(test_train=True, data_dir='./dataset/', dataset='ml-1m', trai
sampler = RandomSampler(train_pos_users.shape[0], num_neg, batch_size)
if rank_id is not None and rank_size is not None:
sampler = DistributedSamplerOfTrain(train_pos_users.shape[0], num_neg, batch_size, rank_id, rank_size)
if dataset == 'ml-20m':
ds = GeneratorDataset(dataset,
column_names=[movielens.USER_COLUMN,
movielens.ITEM_COLUMN,
"labels",
rconst.VALID_POINT_MASK],
sampler=sampler, num_parallel_workers=32, python_multiprocessing=False)
else:
ds = GeneratorDataset(dataset,
column_names=[movielens.USER_COLUMN,
movielens.ITEM_COLUMN,
"labels",
rconst.VALID_POINT_MASK],
sampler=sampler)
ds = GeneratorDataset(dataset,
column_names=[movielens.USER_COLUMN,
movielens.ITEM_COLUMN,
"labels",
rconst.VALID_POINT_MASK],
sampler=sampler)
else:
eval_batch_size = parse_eval_batch_size(eval_batch_size=eval_batch_size)