forked from mindspore-Ecosystem/mindspore
!13170 padding sample last batch indices to be the same length with previous batch
From: @zhouneng2 Reviewed-by: @linqingke Signed-off-by: @linqingke
This commit is contained in:
commit
a873ef5b9f
|
@ -22,7 +22,7 @@ import pickle
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from mindspore.dataset import GeneratorDataset
|
||||
from mindspore.dataset import GeneratorDataset, Sampler
|
||||
|
||||
import src.constants as rconst
|
||||
import src.movielens as movielens
|
||||
|
@ -214,6 +214,7 @@ class NCFDataset:
|
|||
total_negatives,
|
||||
index_bounds,
|
||||
sorted_train_pos_items,
|
||||
num_neg,
|
||||
is_training=True):
|
||||
self._pos_users = pos_users
|
||||
self._pos_items = pos_items
|
||||
|
@ -234,6 +235,10 @@ class NCFDataset:
|
|||
self._eval_users_per_batch = int(
|
||||
batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
|
||||
|
||||
_pos_count = pos_users.shape[0]
|
||||
_num_samples = (1 + num_neg) * _pos_count
|
||||
self.dataset_len = math.ceil(_num_samples / batch_size)
|
||||
|
||||
def lookup_negative_items(self, negative_users):
|
||||
"""Lookup negative items"""
|
||||
output = np.zeros(shape=negative_users.shape, dtype=rconst.ITEM_DTYPE) - 1
|
||||
|
@ -402,8 +407,14 @@ class NCFDataset:
|
|||
|
||||
return self._get_eval_item(index)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Return length of the dataset, i.e., the number of batches for an epoch
|
||||
"""
|
||||
return self.dataset_len
|
||||
|
||||
class RandomSampler:
|
||||
|
||||
class RandomSampler(Sampler):
|
||||
"""
|
||||
A random sampler for dataset.
|
||||
"""
|
||||
|
@ -413,6 +424,7 @@ class RandomSampler:
|
|||
self._num_samples = (1 + num_train_negatives) * self.pos_count
|
||||
self._batch_size = batch_size
|
||||
self._num_batches = math.ceil(self._num_samples / self._batch_size)
|
||||
super().__init__(self._num_batches)
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
|
@ -421,13 +433,14 @@ class RandomSampler:
|
|||
indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
|
||||
|
||||
batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._num_batches)]
|
||||
return iter(batch_indices)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Return length of the sampler, i.e., the number of batches for an epoch.
|
||||
"""
|
||||
return self._num_batches
|
||||
# padding last batch indices if necessary
|
||||
if len(batch_indices) > 2 and len(batch_indices[-2]) != len(batch_indices[-1]):
|
||||
pad_nums = len(batch_indices[-2]) - len(batch_indices[-1])
|
||||
pad_indices = np.random.randint(0, self._num_samples, pad_nums)
|
||||
batch_indices[-1] = np.hstack((batch_indices[-1], pad_indices))
|
||||
|
||||
return iter(batch_indices)
|
||||
|
||||
|
||||
class DistributedSamplerOfTrain:
|
||||
|
@ -467,7 +480,7 @@ class DistributedSamplerOfTrain:
|
|||
return self._batchs_per_rank
|
||||
|
||||
|
||||
class SequenceSampler:
|
||||
class SequenceSampler(Sampler):
|
||||
"""
|
||||
A sequence sampler for dataset.
|
||||
"""
|
||||
|
@ -478,10 +491,18 @@ class SequenceSampler:
|
|||
self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
|
||||
self._eval_batches_per_epoch = self.count_batches(
|
||||
self._eval_elements_in_epoch, eval_batch_size)
|
||||
super().__init__(self._eval_batches_per_epoch)
|
||||
|
||||
def __iter__(self):
|
||||
indices = [(x * self._eval_users_per_batch, (x + 1) * self._eval_users_per_batch)
|
||||
for x in range(self._eval_batches_per_epoch)]
|
||||
|
||||
# padding last batch indices if necessary
|
||||
if len(indices) > 2 and len(indices[-2]) != len(indices[-1]):
|
||||
pad_nums = len(indices[-2]) - len(indices[-1])
|
||||
pad_indices = np.random.randint(0, self._eval_elements_in_epoch, pad_nums)
|
||||
indices[-1] = np.hstack((indices[-1], pad_indices))
|
||||
|
||||
return iter(indices)
|
||||
|
||||
@staticmethod
|
||||
|
@ -490,12 +511,6 @@ class SequenceSampler:
|
|||
x = (example_count + batch_size - 1) // batch_size
|
||||
return (x + batches_per_step - 1) // batches_per_step * batches_per_step
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Return the length of the sampler, i,e, the number of batches in an epoch.
|
||||
"""
|
||||
return self._eval_batches_per_epoch
|
||||
|
||||
|
||||
class DistributedSamplerOfEval:
|
||||
"""
|
||||
|
@ -562,7 +577,7 @@ def create_dataset(test_train=True, data_dir='./dataset/', dataset='ml-1m', trai
|
|||
print(train_pos_users, train_pos_items, num_users, num_items, batch_size, total_negatives, index_bounds,
|
||||
sorted_train_pos_items)
|
||||
dataset = NCFDataset(train_pos_users, train_pos_items, num_users, num_items, batch_size, total_negatives,
|
||||
index_bounds, sorted_train_pos_items)
|
||||
index_bounds, sorted_train_pos_items, num_neg)
|
||||
sampler = RandomSampler(train_pos_users.shape[0], num_neg, batch_size)
|
||||
if rank_id is not None and rank_size is not None:
|
||||
sampler = DistributedSamplerOfTrain(train_pos_users.shape[0], num_neg, batch_size, rank_id, rank_size)
|
||||
|
@ -585,7 +600,7 @@ def create_dataset(test_train=True, data_dir='./dataset/', dataset='ml-1m', trai
|
|||
eval_batch_size = parse_eval_batch_size(eval_batch_size=eval_batch_size)
|
||||
dataset = NCFDataset(eval_pos_users, eval_pos_items, num_users, num_items,
|
||||
eval_batch_size, total_negatives, index_bounds,
|
||||
sorted_train_pos_items, is_training=False)
|
||||
sorted_train_pos_items, num_neg, is_training=False)
|
||||
sampler = SequenceSampler(eval_batch_size, num_users)
|
||||
|
||||
ds = GeneratorDataset(dataset,
|
||||
|
|
Loading…
Reference in New Issue