!12154 Fix samplers bugs
From: @hfarahat Reviewed-by: @heleiwang,@liucunwei Signed-off-by: @liucunwei
This commit is contained in:
commit
bf528c6817
|
@ -193,9 +193,9 @@ Status GeneratorOp::operator()() {
|
|||
TaskManager::FindMe()->Post();
|
||||
RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
|
||||
std::unique_ptr<DataBuffer> fetched_buffer;
|
||||
int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_;
|
||||
RETURN_IF_NOT_OK(Init());
|
||||
|
||||
int64_t num_rows_sampled = sampler_ ? sampler_->CalculateNumSamples(num_rows_) : num_rows_;
|
||||
bool eof = false;
|
||||
while (!eof) {
|
||||
// Create new buffer each iteration
|
||||
|
|
|
@ -184,7 +184,6 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) {
|
|||
if (device_id_ < remainder) shard_size++;
|
||||
if (device_id_ < offset_) shard_size--;
|
||||
} else {
|
||||
offset_ = 0;
|
||||
shard_size = (child_num_rows + num_devices_ - 1) / num_devices_;
|
||||
}
|
||||
// add 1 to an empty shard
|
||||
|
|
|
@ -42,7 +42,12 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<
|
|||
|
||||
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema,
|
||||
int64_t source_len, std::shared_ptr<SamplerObj> sampler)
|
||||
: MappableSourceNode(), generator_function_(generator_function), schema_(schema), reset_ancestor_(nullptr) {}
|
||||
: MappableSourceNode(),
|
||||
generator_function_(generator_function),
|
||||
schema_(schema),
|
||||
reset_ancestor_(nullptr),
|
||||
sampler_(std::move(sampler)),
|
||||
source_len_(source_len) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
|
||||
std::shared_ptr<GeneratorNode> node;
|
||||
|
|
|
@ -2233,6 +2233,7 @@ _GLOBAL_PYFUNC_LIST = []
|
|||
_OP_NAME = dict()
|
||||
_OP_PROCESS = dict()
|
||||
|
||||
|
||||
# Pyfunc worker init function
|
||||
# Python multiprocessing library forbid sending lambda function through pipe.
|
||||
# This init function allow us to add all Python function to a global collection and then fork afterwards.
|
||||
|
@ -3781,6 +3782,8 @@ class GeneratorDataset(MappableDataset):
|
|||
try:
|
||||
new_op.sampler = None
|
||||
new_op.sample_fn = sample_fn
|
||||
new_op.source_len = min(new_op.source_len,
|
||||
new_op.num_samples) if new_op.num_samples is not None else new_op.source_len
|
||||
iter(self.source)
|
||||
except TypeError:
|
||||
# Use generator function if input callable
|
||||
|
|
Loading…
Reference in New Issue