forked from mindspore-Ecosystem/mindspore
!24786 fix some warning when use
Merge pull request !24786 from guozhijian/fix_some_warning
This commit is contained in:
commit
9561a0611c
|
@ -191,18 +191,19 @@ PYBIND_REGISTER(
|
||||||
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
|
PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
|
||||||
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
|
(void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>(
|
||||||
*m, "GeneratorNode", "to create a GeneratorNode")
|
*m, "GeneratorNode", "to create a GeneratorNode")
|
||||||
.def(
|
.def(py::init([](py::function generator_function, const std::vector<std::string> &column_names,
|
||||||
py::init([](py::function generator_function, const std::vector<std::string> &column_names,
|
const std::vector<DataType> &column_types, int64_t dataset_len, py::handle sampler,
|
||||||
const std::vector<DataType> &column_types, int64_t dataset_len, py::handle sampler) {
|
uint32_t num_parallel_workers) {
|
||||||
auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types,
|
|
||||||
dataset_len, toSamplerObj(sampler));
|
|
||||||
THROW_IF_ERROR(gen->ValidateParams());
|
|
||||||
return gen;
|
|
||||||
}))
|
|
||||||
.def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema,
|
|
||||||
int64_t dataset_len, py::handle sampler) {
|
|
||||||
auto gen =
|
auto gen =
|
||||||
std::make_shared<GeneratorNode>(generator_function, schema, dataset_len, toSamplerObj(sampler));
|
std::make_shared<GeneratorNode>(generator_function, column_names, column_types, dataset_len,
|
||||||
|
toSamplerObj(sampler), num_parallel_workers);
|
||||||
|
THROW_IF_ERROR(gen->ValidateParams());
|
||||||
|
return gen;
|
||||||
|
}))
|
||||||
|
.def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema,
|
||||||
|
int64_t dataset_len, py::handle sampler, uint32_t num_parallel_workers) {
|
||||||
|
auto gen = std::make_shared<GeneratorNode>(generator_function, schema, dataset_len,
|
||||||
|
toSamplerObj(sampler), num_parallel_workers);
|
||||||
THROW_IF_ERROR(gen->ValidateParams());
|
THROW_IF_ERROR(gen->ValidateParams());
|
||||||
return gen;
|
return gen;
|
||||||
}));
|
}));
|
||||||
|
|
|
@ -171,8 +171,10 @@ Status DeviceQueueOp::operator()() {
|
||||||
Status DeviceQueueOp::SendDataToAscend() {
|
Status DeviceQueueOp::SendDataToAscend() {
|
||||||
MS_LOG(INFO) << "Device queue, sending data to Ascend.";
|
MS_LOG(INFO) << "Device queue, sending data to Ascend.";
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
uint64_t batch_start_time, end_time;
|
uint64_t batch_start_time = 0;
|
||||||
uint64_t batch_record_start, batch_record_end;
|
uint64_t end_time = 0;
|
||||||
|
uint64_t batch_record_start = 0;
|
||||||
|
uint64_t batch_record_end = 0;
|
||||||
#endif
|
#endif
|
||||||
int64_t send_batch = 0;
|
int64_t send_batch = 0;
|
||||||
int32_t tdt_cost = 0;
|
int32_t tdt_cost = 0;
|
||||||
|
@ -221,11 +223,9 @@ Status DeviceQueueOp::SendDataToAscend() {
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
DetectPerBatchTime(&batch_record_start, &batch_record_end);
|
DetectPerBatchTime(&batch_record_start, &batch_record_end);
|
||||||
#endif
|
#endif
|
||||||
|
PrintBeginInfoWhenFirstBatch(first_push_flag_);
|
||||||
RETURN_IF_NOT_OK(SendRowToTdt(curr_row, is_profiling_enable, &tdt_cost));
|
RETURN_IF_NOT_OK(SendRowToTdt(curr_row, is_profiling_enable, &tdt_cost));
|
||||||
if (first_push_flag_ != true) {
|
PrintEndInfoWhenFirstBatch(&first_push_flag_);
|
||||||
MS_LOG(INFO) << "Loading dataset and push first batch into device successful.";
|
|
||||||
first_push_flag_ = true;
|
|
||||||
}
|
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
ProfilingRecorder(is_profiling_enable, profiling_node, send_batch, tdt_cost, &batch_start_time, &end_time,
|
ProfilingRecorder(is_profiling_enable, profiling_node, send_batch, tdt_cost, &batch_start_time, &end_time,
|
||||||
connector_capacity, connector_size);
|
connector_capacity, connector_size);
|
||||||
|
@ -581,11 +581,9 @@ Status DeviceQueueOp::SendDataToGPU() {
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
DetectPerBatchTime(&batch_record_start, &batch_record_end);
|
DetectPerBatchTime(&batch_record_start, &batch_record_end);
|
||||||
#endif
|
#endif
|
||||||
|
PrintBeginInfoWhenFirstBatch(first_push_flag_);
|
||||||
RETURN_IF_NOT_OK(receive_queues_[num_buf++ % num_workers_]->Add(std::move(current_row)));
|
RETURN_IF_NOT_OK(receive_queues_[num_buf++ % num_workers_]->Add(std::move(current_row)));
|
||||||
if (first_push_flag_ != true) {
|
PrintEndInfoWhenFirstBatch(&first_push_flag_);
|
||||||
MS_LOG(INFO) << "Loading dataset and push first batch into device successful.";
|
|
||||||
first_push_flag_ = true;
|
|
||||||
}
|
|
||||||
#ifndef ENABLE_SECURITY
|
#ifndef ENABLE_SECURITY
|
||||||
batch_record_start = ProfilingTime::GetCurMilliSecond();
|
batch_record_start = ProfilingTime::GetCurMilliSecond();
|
||||||
#endif
|
#endif
|
||||||
|
@ -727,6 +725,23 @@ void DeviceQueueOp::DetectPerBatchTime(uint64_t *start_time, uint64_t *end_time)
|
||||||
" performance(with creating dataset iterator) and optimize it.";
|
" performance(with creating dataset iterator) and optimize it.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DeviceQueueOp::PrintBeginInfoWhenFirstBatch(const bool &first_push_flag) {
|
||||||
|
if (first_push_flag != true) {
|
||||||
|
MS_LOG(INFO) << "Loading dataset and begin to push first batch into device ...";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeviceQueueOp::PrintEndInfoWhenFirstBatch(bool *first_push_flag) {
|
||||||
|
if (!first_push_flag) {
|
||||||
|
MS_LOG(WARNING) << "First batch flag: first_push_flag is nullptr";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (*first_push_flag != true) {
|
||||||
|
MS_LOG(INFO) << "Loading dataset and push first batch into device successful.";
|
||||||
|
*first_push_flag = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -119,10 +119,18 @@ class DeviceQueueOp : public PipelineOp {
|
||||||
// Description: Auto filter metadata column before sending to device.
|
// Description: Auto filter metadata column before sending to device.
|
||||||
Status FilterMetadata(TensorRow *row);
|
Status FilterMetadata(TensorRow *row);
|
||||||
|
|
||||||
// Name: checkExceptions(TensorRow);
|
// Name: CheckExceptions(TensorRow);
|
||||||
// Description: Check whether the TensorRow meets the condition for performing DeviceQueueOp
|
// Description: Check whether the TensorRow meets the condition for performing DeviceQueueOp
|
||||||
Status CheckExceptions(const TensorRow &row) const;
|
Status CheckExceptions(const TensorRow &row) const;
|
||||||
|
|
||||||
|
// Name: PrintBeginInfoWhenFirstBatch(bool)
|
||||||
|
// Description: Print info when first batch begin to send in sink_mode
|
||||||
|
void PrintBeginInfoWhenFirstBatch(const bool &first_push_flag);
|
||||||
|
|
||||||
|
// Name: PrintEndInfoWhenFirstBatch(bool)
|
||||||
|
// Description: Print info when first batch send successful in sink_mode
|
||||||
|
void PrintEndInfoWhenFirstBatch(bool *first_push_flag);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
#ifdef ENABLE_TDTQUE
|
#ifdef ENABLE_TDTQUE
|
||||||
void WaitContinueSignal() const;
|
void WaitContinueSignal() const;
|
||||||
|
|
|
@ -14,9 +14,10 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
|
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
|
||||||
#include <iomanip>
|
|
||||||
#include "minddata/dataset/core/global_context.h"
|
|
||||||
|
|
||||||
|
#include <iomanip>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/global_context.h"
|
||||||
#include "minddata/dataset/engine/execution_tree.h"
|
#include "minddata/dataset/engine/execution_tree.h"
|
||||||
#include "minddata/dataset/util/task_manager.h"
|
#include "minddata/dataset/util/task_manager.h"
|
||||||
|
|
||||||
|
@ -24,13 +25,14 @@ namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
|
GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
|
||||||
std::vector<DataType> column_types, int32_t prefetch_size, int32_t connector_size,
|
std::vector<DataType> column_types, int32_t prefetch_size, int32_t connector_size,
|
||||||
std::shared_ptr<SamplerRT> sampler)
|
std::shared_ptr<SamplerRT> sampler, uint32_t num_parallel_workers)
|
||||||
: PipelineOp(connector_size, std::move(sampler)),
|
: PipelineOp(connector_size, std::move(sampler)),
|
||||||
generator_function_(generator_function),
|
generator_function_(generator_function),
|
||||||
column_names_(column_names),
|
column_names_(column_names),
|
||||||
column_types_(std::move(column_types)),
|
column_types_(std::move(column_types)),
|
||||||
prefetch_size_(prefetch_size),
|
prefetch_size_(prefetch_size),
|
||||||
generator_counter_(0) {}
|
generator_counter_(0),
|
||||||
|
num_parallel_workers_(num_parallel_workers) {}
|
||||||
|
|
||||||
void GeneratorOp::Print(std::ostream &out, bool show_all) const {
|
void GeneratorOp::Print(std::ostream &out, bool show_all) const {
|
||||||
if (!show_all) {
|
if (!show_all) {
|
||||||
|
@ -174,7 +176,15 @@ Status GeneratorOp::operator()() {
|
||||||
return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized");
|
return Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized");
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
|
auto start = ProfilingTime::GetCurMilliSecond();
|
||||||
RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &new_row));
|
RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &new_row));
|
||||||
|
auto end = ProfilingTime::GetCurMilliSecond();
|
||||||
|
if ((end - start) / num_parallel_workers_ > kGetItemTimeOutMilliSeconds) {
|
||||||
|
MS_LOG(WARNING) << "Bad performance attention, it takes more than 25 seconds to generator.__next__ new row, "
|
||||||
|
"which might cause `GetNext` timeout problem when sink_mode=True. You can increase the "
|
||||||
|
"parameter num_parallel_workers in GeneratorDataset / optimize the efficiency of "
|
||||||
|
"obtaining samples in the user-defined generator function.";
|
||||||
|
}
|
||||||
generator_counter_++;
|
generator_counter_++;
|
||||||
} catch (py::error_already_set &e) {
|
} catch (py::error_already_set &e) {
|
||||||
eoe = e.matches(PyExc_StopIteration);
|
eoe = e.matches(PyExc_StopIteration);
|
||||||
|
|
|
@ -36,11 +36,13 @@ namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
#pragma GCC visibility push(hidden)
|
#pragma GCC visibility push(hidden)
|
||||||
|
|
||||||
|
constexpr int32_t kGetItemTimeOutMilliSeconds = 25000;
|
||||||
|
|
||||||
class GeneratorOp : public PipelineOp, public RandomAccessOp {
|
class GeneratorOp : public PipelineOp, public RandomAccessOp {
|
||||||
public:
|
public:
|
||||||
GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
|
GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
|
||||||
std::vector<DataType> column_types, int32_t prefetch_size, int32_t connector_size,
|
std::vector<DataType> column_types, int32_t prefetch_size, int32_t connector_size,
|
||||||
std::shared_ptr<SamplerRT> sampler);
|
std::shared_ptr<SamplerRT> sampler, uint32_t num_parallel_workers);
|
||||||
|
|
||||||
~GeneratorOp() = default;
|
~GeneratorOp() = default;
|
||||||
|
|
||||||
|
@ -81,6 +83,7 @@ class GeneratorOp : public PipelineOp, public RandomAccessOp {
|
||||||
std::vector<DataType> column_types_;
|
std::vector<DataType> column_types_;
|
||||||
int32_t prefetch_size_;
|
int32_t prefetch_size_;
|
||||||
int64_t generator_counter_;
|
int64_t generator_counter_;
|
||||||
|
uint32_t num_parallel_workers_;
|
||||||
|
|
||||||
py::object generator_;
|
py::object generator_;
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||||
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
|
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
|
||||||
#include "minddata/dataset/engine/opt/pass.h"
|
#include "minddata/dataset/engine/opt/pass.h"
|
||||||
|
@ -26,30 +28,33 @@ namespace dataset {
|
||||||
|
|
||||||
GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
|
GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
|
||||||
const std::vector<DataType> &column_types, int64_t source_len,
|
const std::vector<DataType> &column_types, int64_t source_len,
|
||||||
std::shared_ptr<SamplerObj> sampler)
|
std::shared_ptr<SamplerObj> sampler, uint32_t num_parallel_workers)
|
||||||
: MappableSourceNode(),
|
: MappableSourceNode(),
|
||||||
generator_function_(generator_function),
|
generator_function_(generator_function),
|
||||||
column_names_(column_names),
|
column_names_(column_names),
|
||||||
column_types_(column_types),
|
column_types_(column_types),
|
||||||
reset_ancestor_(nullptr),
|
reset_ancestor_(nullptr),
|
||||||
sampler_(std::move(sampler)),
|
sampler_(std::move(sampler)),
|
||||||
source_len_(source_len) {}
|
source_len_(source_len),
|
||||||
|
num_parallel_workers_(num_parallel_workers) {}
|
||||||
|
|
||||||
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema,
|
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema,
|
||||||
int64_t source_len, std::shared_ptr<SamplerObj> sampler)
|
int64_t source_len, std::shared_ptr<SamplerObj> sampler, uint32_t num_parallel_workers)
|
||||||
: MappableSourceNode(),
|
: MappableSourceNode(),
|
||||||
generator_function_(generator_function),
|
generator_function_(generator_function),
|
||||||
schema_(schema),
|
schema_(schema),
|
||||||
reset_ancestor_(nullptr),
|
reset_ancestor_(nullptr),
|
||||||
sampler_(std::move(sampler)),
|
sampler_(std::move(sampler)),
|
||||||
source_len_(source_len) {}
|
source_len_(source_len),
|
||||||
|
num_parallel_workers_(num_parallel_workers) {}
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
|
std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
|
||||||
std::shared_ptr<GeneratorNode> node;
|
std::shared_ptr<GeneratorNode> node;
|
||||||
if (schema_ == nullptr) {
|
if (schema_ == nullptr) {
|
||||||
node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_, source_len_, sampler_);
|
node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_, source_len_, sampler_,
|
||||||
|
num_parallel_workers_);
|
||||||
} else {
|
} else {
|
||||||
node = std::make_shared<GeneratorNode>(generator_function_, schema_, source_len_, sampler_);
|
node = std::make_shared<GeneratorNode>(generator_function_, schema_, source_len_, sampler_, num_parallel_workers_);
|
||||||
}
|
}
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
@ -78,8 +83,8 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
|
||||||
|
|
||||||
// GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
|
// GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
|
||||||
// GeneratorOp internally. Here it is given a zero which is the default in generator builder
|
// GeneratorOp internally. Here it is given a zero which is the default in generator builder
|
||||||
std::shared_ptr<GeneratorOp> op = std::make_shared<GeneratorOp>(generator_function_, column_names_, column_types_, 0,
|
std::shared_ptr<GeneratorOp> op = std::make_shared<GeneratorOp>(
|
||||||
connector_que_size_, sampler_rt);
|
generator_function_, column_names_, column_types_, 0, connector_que_size_, sampler_rt, num_parallel_workers_);
|
||||||
// set the number of rows from source length
|
// set the number of rows from source length
|
||||||
op->SetNumRows(source_len_);
|
op->SetNumRows(source_len_);
|
||||||
|
|
||||||
|
|
|
@ -34,11 +34,12 @@ class GeneratorNode : public MappableSourceNode {
|
||||||
public:
|
public:
|
||||||
/// \brief Constructor
|
/// \brief Constructor
|
||||||
GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
|
GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
|
||||||
const std::vector<DataType> &column_types, int64_t source_len, std::shared_ptr<SamplerObj> sampler);
|
const std::vector<DataType> &column_types, int64_t source_len, std::shared_ptr<SamplerObj> sampler,
|
||||||
|
uint32_t num_parallel_workers);
|
||||||
|
|
||||||
/// \brief Constructor
|
/// \brief Constructor
|
||||||
GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema, int64_t source_len,
|
GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema, int64_t source_len,
|
||||||
std::shared_ptr<SamplerObj> sampler);
|
std::shared_ptr<SamplerObj> sampler, uint32_t num_parallel_workers);
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
~GeneratorNode() = default;
|
~GeneratorNode() = default;
|
||||||
|
@ -107,6 +108,7 @@ class GeneratorNode : public MappableSourceNode {
|
||||||
std::shared_ptr<SchemaObj> schema_;
|
std::shared_ptr<SchemaObj> schema_;
|
||||||
std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass
|
std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass
|
||||||
std::shared_ptr<SamplerObj> sampler_;
|
std::shared_ptr<SamplerObj> sampler_;
|
||||||
|
uint32_t num_parallel_workers_;
|
||||||
int64_t source_len_; // Length of the dataset source provided by the user, -1 means it's unknown
|
int64_t source_len_; // Length of the dataset source provided by the user, -1 means it's unknown
|
||||||
|
|
||||||
/// \brief Base-class override for accepting IRNodePass visitor
|
/// \brief Base-class override for accepting IRNodePass visitor
|
||||||
|
|
|
@ -4326,11 +4326,12 @@ class GeneratorDataset(MappableDataset):
|
||||||
def parse(self, children=None):
|
def parse(self, children=None):
|
||||||
if self.schema is None:
|
if self.schema is None:
|
||||||
return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
|
return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
|
||||||
self.sampler)
|
self.sampler, self.num_parallel_workers)
|
||||||
schema = self.schema
|
schema = self.schema
|
||||||
if isinstance(schema, Schema):
|
if isinstance(schema, Schema):
|
||||||
schema = self.schema.cpp_schema
|
schema = self.schema.cpp_schema
|
||||||
return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler)
|
return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler,
|
||||||
|
self.num_parallel_workers)
|
||||||
|
|
||||||
|
|
||||||
class TFRecordDataset(SourceDataset):
|
class TFRecordDataset(SourceDataset):
|
||||||
|
|
|
@ -380,6 +380,7 @@ class Model:
|
||||||
epoch_num=epoch_num,
|
epoch_num=epoch_num,
|
||||||
dataset_helper=dataset_helper)
|
dataset_helper=dataset_helper)
|
||||||
train_dataset._dataset_helper = dataset_helper
|
train_dataset._dataset_helper = dataset_helper
|
||||||
|
train_dataset._warmup_epoch = epoch
|
||||||
|
|
||||||
def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
|
def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
|
||||||
"""
|
"""
|
||||||
|
@ -701,6 +702,12 @@ class Model:
|
||||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||||
if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
|
if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
|
||||||
raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")
|
raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")
|
||||||
|
|
||||||
|
if hasattr(train_dataset, '_warmup_epoch') and train_dataset._warmup_epoch != epoch:
|
||||||
|
raise ValueError("Use Model.build to initialize model, but the value of parameter `epoch` in Model.build "
|
||||||
|
"is not equal to value in Model.train, got {} and {} separately."
|
||||||
|
.format(train_dataset._warmup_epoch, epoch))
|
||||||
|
|
||||||
Validator.check_is_int(sink_size)
|
Validator.check_is_int(sink_size)
|
||||||
dataset_size = train_dataset.get_dataset_size()
|
dataset_size = train_dataset.get_dataset_size()
|
||||||
if dataset_size == 0:
|
if dataset_size == 0:
|
||||||
|
|
Loading…
Reference in New Issue