Fix mismatch between the len method and the iter method of a generator or user-defined sampler
This commit is contained in:
parent
5e5489d59f
commit
50b783ee13
|
@ -49,14 +49,16 @@ Status GeneratorOp::Builder::Build(std::shared_ptr<GeneratorOp> *ptr) {
|
|||
|
||||
GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
|
||||
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size,
|
||||
int32_t connector_size)
|
||||
int32_t connector_size, int64_t pre_counter_size)
|
||||
: PipelineOp(connector_size),
|
||||
generator_function_(generator_function),
|
||||
column_names_(column_names),
|
||||
column_types_(column_types),
|
||||
prefetch_size_(prefetch_size),
|
||||
buffer_size_(buffer_size),
|
||||
buffer_id_(0) {}
|
||||
pre_counter_size_(pre_counter_size),
|
||||
buffer_id_(0),
|
||||
generator_counter_(0) {}
|
||||
|
||||
GeneratorOp::~GeneratorOp() { this->Dealloc(); }
|
||||
|
||||
|
@ -146,6 +148,7 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) {
|
|||
TensorRow row;
|
||||
RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &row));
|
||||
tt->push_back(std::move(row));
|
||||
generator_counter_++;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -209,6 +212,13 @@ Status GeneratorOp::operator()() {
|
|||
if (!eoe) {
|
||||
return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what());
|
||||
}
|
||||
if (pre_counter_size_ != -1 && pre_counter_size_ != generator_counter_) {
|
||||
std::stringstream ss;
|
||||
ss << "The actual amount of data read from generator " << generator_counter_
|
||||
<< " is different from generator.len " << pre_counter_size_
|
||||
<< ", you should adjust generator.len to make them match.";
|
||||
return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, ss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (fetched_table->size() > 0) {
|
||||
|
@ -254,6 +264,7 @@ Status GeneratorOp::Reset() {
|
|||
// Wake up master thread
|
||||
wp_.Set();
|
||||
}
|
||||
generator_counter_ = 0;
|
||||
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
|
||||
}
|
||||
|
||||
|
|
|
@ -93,7 +93,8 @@ class GeneratorOp : public PipelineOp {
|
|||
};
|
||||
|
||||
GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
|
||||
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size);
|
||||
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size,
|
||||
int64_t pre_counter_size = 0);
|
||||
|
||||
~GeneratorOp();
|
||||
|
||||
|
@ -142,6 +143,8 @@ class GeneratorOp : public PipelineOp {
|
|||
std::vector<DataType> column_types_;
|
||||
int32_t prefetch_size_;
|
||||
int32_t buffer_size_;
|
||||
int64_t pre_counter_size_;
|
||||
int64_t generator_counter_;
|
||||
|
||||
py::object generator_;
|
||||
int32_t buffer_id_;
|
||||
|
|
|
@ -46,6 +46,7 @@ std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
|
|||
} else {
|
||||
node = std::make_shared<GeneratorNode>(generator_function_, schema_);
|
||||
}
|
||||
node->SetGeneratorDatasetSize(dataset_size_);
|
||||
return node;
|
||||
}
|
||||
|
||||
|
@ -72,7 +73,7 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
|
|||
// GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
|
||||
// GeneratorOp internally. Here it is given a zero which is the default in generator builder
|
||||
std::shared_ptr<GeneratorOp> op = std::make_shared<GeneratorOp>(generator_function_, column_names_, column_types_, 0,
|
||||
rows_per_buffer_, connector_que_size_);
|
||||
rows_per_buffer_, connector_que_size_, dataset_size_);
|
||||
|
||||
// Init() is called in builder when generator is built. Here, since we are getting away from the builder class, init
|
||||
// needs to be called when the op is built. The caveat is that Init needs to be made public (before it is private).
|
||||
|
|
|
@ -2663,7 +2663,7 @@ class ConcatDataset(Dataset):
|
|||
|
||||
tem_sampler = copy.deepcopy(sampler)
|
||||
tem_sampler.set_offset(cumulative_samples_nums)
|
||||
child.sampler = tem_sampler
|
||||
child.use_sampler(tem_sampler)
|
||||
|
||||
cumulative_samples_nums += self.children_sizes_[index]
|
||||
cumulative_samples_nums %= sampler.num_shards
|
||||
|
@ -3808,6 +3808,8 @@ class GeneratorDataset(MappableDataset):
|
|||
self.dataset_size = math.ceil(len(self.source) / self.num_shards)
|
||||
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
if self.num_samples is not None and self.num_samples < rows_from_sampler:
|
||||
rows_from_sampler = self.num_samples
|
||||
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
|
||||
self.dataset_size = rows_from_sampler
|
||||
|
||||
|
|
Loading…
Reference in New Issue