forked from mindspore-Ecosystem/mindspore
!23150 improve NCF training speed
Merge pull request !23150 from zhouneng/code_docs_fix_issue_I48UPH
This commit is contained in:
commit
7a3a1dbb28
|
@ -579,20 +579,13 @@ def create_dataset(test_train=True, data_dir='./dataset/', dataset='ml-1m', trai
|
|||
sampler = RandomSampler(train_pos_users.shape[0], num_neg, batch_size)
|
||||
if rank_id is not None and rank_size is not None:
|
||||
sampler = DistributedSamplerOfTrain(train_pos_users.shape[0], num_neg, batch_size, rank_id, rank_size)
|
||||
if dataset == 'ml-20m':
|
||||
ds = GeneratorDataset(dataset,
|
||||
column_names=[movielens.USER_COLUMN,
|
||||
movielens.ITEM_COLUMN,
|
||||
"labels",
|
||||
rconst.VALID_POINT_MASK],
|
||||
sampler=sampler, num_parallel_workers=32, python_multiprocessing=False)
|
||||
else:
|
||||
ds = GeneratorDataset(dataset,
|
||||
column_names=[movielens.USER_COLUMN,
|
||||
movielens.ITEM_COLUMN,
|
||||
"labels",
|
||||
rconst.VALID_POINT_MASK],
|
||||
sampler=sampler)
|
||||
|
||||
ds = GeneratorDataset(dataset,
|
||||
column_names=[movielens.USER_COLUMN,
|
||||
movielens.ITEM_COLUMN,
|
||||
"labels",
|
||||
rconst.VALID_POINT_MASK],
|
||||
sampler=sampler)
|
||||
|
||||
else:
|
||||
eval_batch_size = parse_eval_batch_size(eval_batch_size=eval_batch_size)
|
||||
|
|
Loading…
Reference in New Issue