fix generator or user defined sampler len method unmatch iter method
This commit is contained in:
parent
5e5489d59f
commit
50b783ee13
|
@ -49,14 +49,16 @@ Status GeneratorOp::Builder::Build(std::shared_ptr<GeneratorOp> *ptr) {
|
||||||
|
|
||||||
GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
|
GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
|
||||||
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size,
|
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size,
|
||||||
int32_t connector_size)
|
int32_t connector_size, int64_t pre_counter_size)
|
||||||
: PipelineOp(connector_size),
|
: PipelineOp(connector_size),
|
||||||
generator_function_(generator_function),
|
generator_function_(generator_function),
|
||||||
column_names_(column_names),
|
column_names_(column_names),
|
||||||
column_types_(column_types),
|
column_types_(column_types),
|
||||||
prefetch_size_(prefetch_size),
|
prefetch_size_(prefetch_size),
|
||||||
buffer_size_(buffer_size),
|
buffer_size_(buffer_size),
|
||||||
buffer_id_(0) {}
|
pre_counter_size_(pre_counter_size),
|
||||||
|
buffer_id_(0),
|
||||||
|
generator_counter_(0) {}
|
||||||
|
|
||||||
GeneratorOp::~GeneratorOp() { this->Dealloc(); }
|
GeneratorOp::~GeneratorOp() { this->Dealloc(); }
|
||||||
|
|
||||||
|
@ -146,6 +148,7 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) {
|
||||||
TensorRow row;
|
TensorRow row;
|
||||||
RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &row));
|
RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &row));
|
||||||
tt->push_back(std::move(row));
|
tt->push_back(std::move(row));
|
||||||
|
generator_counter_++;
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -209,6 +212,13 @@ Status GeneratorOp::operator()() {
|
||||||
if (!eoe) {
|
if (!eoe) {
|
||||||
return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what());
|
return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what());
|
||||||
}
|
}
|
||||||
|
if (pre_counter_size_ != -1 && pre_counter_size_ != generator_counter_) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "The actual amount of data read from generator " << generator_counter_
|
||||||
|
<< " is different from generator.len " << pre_counter_size_
|
||||||
|
<< ", you should adjust generator.len to make them match.";
|
||||||
|
return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, ss.str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (fetched_table->size() > 0) {
|
if (fetched_table->size() > 0) {
|
||||||
|
@ -254,6 +264,7 @@ Status GeneratorOp::Reset() {
|
||||||
// Wake up master thread
|
// Wake up master thread
|
||||||
wp_.Set();
|
wp_.Set();
|
||||||
}
|
}
|
||||||
|
generator_counter_ = 0;
|
||||||
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
|
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -93,7 +93,8 @@ class GeneratorOp : public PipelineOp {
|
||||||
};
|
};
|
||||||
|
|
||||||
GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
|
GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
|
||||||
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size);
|
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size,
|
||||||
|
int64_t pre_counter_size = 0);
|
||||||
|
|
||||||
~GeneratorOp();
|
~GeneratorOp();
|
||||||
|
|
||||||
|
@ -142,6 +143,8 @@ class GeneratorOp : public PipelineOp {
|
||||||
std::vector<DataType> column_types_;
|
std::vector<DataType> column_types_;
|
||||||
int32_t prefetch_size_;
|
int32_t prefetch_size_;
|
||||||
int32_t buffer_size_;
|
int32_t buffer_size_;
|
||||||
|
int64_t pre_counter_size_;
|
||||||
|
int64_t generator_counter_;
|
||||||
|
|
||||||
py::object generator_;
|
py::object generator_;
|
||||||
int32_t buffer_id_;
|
int32_t buffer_id_;
|
||||||
|
|
|
@ -46,6 +46,7 @@ std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
|
||||||
} else {
|
} else {
|
||||||
node = std::make_shared<GeneratorNode>(generator_function_, schema_);
|
node = std::make_shared<GeneratorNode>(generator_function_, schema_);
|
||||||
}
|
}
|
||||||
|
node->SetGeneratorDatasetSize(dataset_size_);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,7 +73,7 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
|
||||||
// GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
|
// GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
|
||||||
// GeneratorOp internally. Here it is given a zero which is the default in generator builder
|
// GeneratorOp internally. Here it is given a zero which is the default in generator builder
|
||||||
std::shared_ptr<GeneratorOp> op = std::make_shared<GeneratorOp>(generator_function_, column_names_, column_types_, 0,
|
std::shared_ptr<GeneratorOp> op = std::make_shared<GeneratorOp>(generator_function_, column_names_, column_types_, 0,
|
||||||
rows_per_buffer_, connector_que_size_);
|
rows_per_buffer_, connector_que_size_, dataset_size_);
|
||||||
|
|
||||||
// Init() is called in builder when generator is built. Here, since we are getting away from the builder class, init
|
// Init() is called in builder when generator is built. Here, since we are getting away from the builder class, init
|
||||||
// needs to be called when the op is built. The caveat is that Init needs to be made public (before it is private).
|
// needs to be called when the op is built. The caveat is that Init needs to be made public (before it is private).
|
||||||
|
|
|
@ -2663,7 +2663,7 @@ class ConcatDataset(Dataset):
|
||||||
|
|
||||||
tem_sampler = copy.deepcopy(sampler)
|
tem_sampler = copy.deepcopy(sampler)
|
||||||
tem_sampler.set_offset(cumulative_samples_nums)
|
tem_sampler.set_offset(cumulative_samples_nums)
|
||||||
child.sampler = tem_sampler
|
child.use_sampler(tem_sampler)
|
||||||
|
|
||||||
cumulative_samples_nums += self.children_sizes_[index]
|
cumulative_samples_nums += self.children_sizes_[index]
|
||||||
cumulative_samples_nums %= sampler.num_shards
|
cumulative_samples_nums %= sampler.num_shards
|
||||||
|
@ -3808,6 +3808,8 @@ class GeneratorDataset(MappableDataset):
|
||||||
self.dataset_size = math.ceil(len(self.source) / self.num_shards)
|
self.dataset_size = math.ceil(len(self.source) / self.num_shards)
|
||||||
|
|
||||||
rows_from_sampler = self._get_sampler_dataset_size()
|
rows_from_sampler = self._get_sampler_dataset_size()
|
||||||
|
if self.num_samples is not None and self.num_samples < rows_from_sampler:
|
||||||
|
rows_from_sampler = self.num_samples
|
||||||
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
|
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
|
||||||
self.dataset_size = rows_from_sampler
|
self.dataset_size = rows_from_sampler
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue