forked from mindspore-Ecosystem/mindspore
!13170 pad the last batch of sample indices to the same length as the previous batch
From: @zhouneng2 Reviewed-by: @linqingke Signed-off-by: @linqingke
commit a873ef5b9f
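The samplers below yield whole batches of indices at a time; when the sample count does not divide the batch size, the final index list comes up short and produces a differently shaped batch. A minimal sketch of the padding idea with toy numbers (10 samples, batch size 4; the numbers are illustrative, not from the commit):

    import math
    import numpy as np

    num_samples, batch_size = 10, 4
    num_batches = math.ceil(num_samples / batch_size)   # 3 batches: 4 + 4 + 2

    indices = np.random.permutation(num_samples)
    batches = [indices[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]

    # Pad the short last batch (2 indices) with randomly re-drawn indices
    # so it matches the length of the previous batch (4).
    if len(batches) > 2 and len(batches[-2]) != len(batches[-1]):
        pad = np.random.randint(0, num_samples, len(batches[-2]) - len(batches[-1]))
        batches[-1] = np.hstack((batches[-1], pad))

    assert all(len(b) == 4 for b in batches)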
@@ -22,7 +22,7 @@ import pickle
 import numpy as np
 import pandas as pd
 
-from mindspore.dataset import GeneratorDataset
+from mindspore.dataset import GeneratorDataset, Sampler
 
 import src.constants as rconst
 import src.movielens as movielens
@@ -214,6 +214,7 @@ class NCFDataset:
                  total_negatives,
                  index_bounds,
                  sorted_train_pos_items,
+                 num_neg,
                  is_training=True):
         self._pos_users = pos_users
         self._pos_items = pos_items
@@ -234,6 +235,10 @@ class NCFDataset:
         self._eval_users_per_batch = int(
             batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
 
+        _pos_count = pos_users.shape[0]
+        _num_samples = (1 + num_neg) * _pos_count
+        self.dataset_len = math.ceil(_num_samples / batch_size)
+
     def lookup_negative_items(self, negative_users):
         """Lookup negative items"""
         output = np.zeros(shape=negative_users.shape, dtype=rconst.ITEM_DTYPE) - 1
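For a feel of the new dataset_len bookkeeping, with illustrative figures (1000 positive pairs, num_neg = 4, batch_size = 256; none of these come from the commit):

    _num_samples = (1 + 4) * 1000          # 5000 samples per epoch
    dataset_len = math.ceil(5000 / 256)    # 20 batches, the last one short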
@@ -402,8 +407,14 @@ class NCFDataset:
 
         return self._get_eval_item(index)
 
+    def __len__(self):
+        """
+        Return length of the dataset, i.e., the number of batches for an epoch
+        """
+        return self.dataset_len
+
 
-class RandomSampler:
+class RandomSampler(Sampler):
     """
     A random sampler for dataset.
     """
@@ -413,6 +424,7 @@ class RandomSampler:
         self._num_samples = (1 + num_train_negatives) * self.pos_count
         self._batch_size = batch_size
         self._num_batches = math.ceil(self._num_samples / self._batch_size)
+        super().__init__(self._num_batches)
 
     def __iter__(self):
         """
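Deriving from mindspore.dataset.Sampler and passing the batch count to super().__init__ lets the pipeline learn the sampler's length from the base class, which is why the hand-written __len__ methods disappear in the hunks below. A minimal user-defined sampler in the same documented style (the even-index selection is purely illustrative):

    import mindspore.dataset as ds

    class EvenSampler(ds.Sampler):
        def __iter__(self):
            # yield every second index of a 10-element dataset
            for i in range(0, 10, 2):
                yield i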
@@ -421,13 +433,14 @@ class RandomSampler:
         indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
 
         batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._num_batches)]
-        return iter(batch_indices)
 
-    def __len__(self):
-        """
-        Return length of the sampler, i.e., the number of batches for an epoch.
-        """
-        return self._num_batches
+        # padding last batch indices if necessary
+        if len(batch_indices) > 2 and len(batch_indices[-2]) != len(batch_indices[-1]):
+            pad_nums = len(batch_indices[-2]) - len(batch_indices[-1])
+            pad_indices = np.random.randint(0, self._num_samples, pad_nums)
+            batch_indices[-1] = np.hstack((batch_indices[-1], pad_indices))
+
+        return iter(batch_indices)
 
 
 class DistributedSamplerOfTrain:
@@ -467,7 +480,7 @@ class DistributedSamplerOfTrain:
         return self._batchs_per_rank
 
 
-class SequenceSampler:
+class SequenceSampler(Sampler):
     """
     A sequence sampler for dataset.
     """
@@ -478,10 +491,18 @@ class SequenceSampler:
         self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
         self._eval_batches_per_epoch = self.count_batches(
             self._eval_elements_in_epoch, eval_batch_size)
+        super().__init__(self._eval_batches_per_epoch)
 
     def __iter__(self):
         indices = [(x * self._eval_users_per_batch, (x + 1) * self._eval_users_per_batch)
                    for x in range(self._eval_batches_per_epoch)]
+
+        # padding last batch indices if necessary
+        if len(indices) > 2 and len(indices[-2]) != len(indices[-1]):
+            pad_nums = len(indices[-2]) - len(indices[-1])
+            pad_indices = np.random.randint(0, self._eval_elements_in_epoch, pad_nums)
+            indices[-1] = np.hstack((indices[-1], pad_indices))
+
         return iter(indices)
 
     @staticmethod
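Each evaluation batch packs one positive plus rconst.NUM_EVAL_NEGATIVES negatives per user, so a batch covers eval_batch_size // (1 + NUM_EVAL_NEGATIVES) users, and __iter__ hands out (start, end) user ranges, one per batch. With illustrative numbers (999 negatives, as in the reference NCF setup, and an eval batch of 100000 rows):

    users_per_batch = 100000 // (1 + 999)   # 100 users per batch
    # __iter__ then yields the ranges (0, 100), (100, 200), ...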
@@ -490,12 +511,6 @@ class SequenceSampler:
         x = (example_count + batch_size - 1) // batch_size
         return (x + batches_per_step - 1) // batches_per_step * batches_per_step
 
-    def __len__(self):
-        """
-        Return the length of the sampler, i,e, the number of batches in an epoch.
-        """
-        return self._eval_batches_per_epoch
-
 
 class DistributedSamplerOfEval:
     """
@@ -562,7 +577,7 @@ def create_dataset(test_train=True, data_dir='./dataset/', dataset='ml-1m', trai
     print(train_pos_users, train_pos_items, num_users, num_items, batch_size, total_negatives, index_bounds,
           sorted_train_pos_items)
     dataset = NCFDataset(train_pos_users, train_pos_items, num_users, num_items, batch_size, total_negatives,
-                         index_bounds, sorted_train_pos_items)
+                         index_bounds, sorted_train_pos_items, num_neg)
     sampler = RandomSampler(train_pos_users.shape[0], num_neg, batch_size)
     if rank_id is not None and rank_size is not None:
         sampler = DistributedSamplerOfTrain(train_pos_users.shape[0], num_neg, batch_size, rank_id, rank_size)
@@ -585,7 +600,7 @@ def create_dataset(test_train=True, data_dir='./dataset/', dataset='ml-1m', trai
     eval_batch_size = parse_eval_batch_size(eval_batch_size=eval_batch_size)
     dataset = NCFDataset(eval_pos_users, eval_pos_items, num_users, num_items,
                          eval_batch_size, total_negatives, index_bounds,
-                         sorted_train_pos_items, is_training=False)
+                         sorted_train_pos_items, num_neg, is_training=False)
     sampler = SequenceSampler(eval_batch_size, num_users)
 
     ds = GeneratorDataset(dataset,