forked from mindspore-Ecosystem/mindspore
add callback to map's c api
fix test err fix ci round 1 fix ci round 2 add col_order to batch_node minor fix
This commit is contained in:
parent
a9a9c45662
commit
7538f82da9
|
@ -499,9 +499,10 @@ FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<Tenso
|
|||
|
||||
MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
|
||||
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
|
||||
const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto ds =
|
||||
std::make_shared<MapNode>(input->IRNode(), operations, input_columns, output_columns, project_columns, cache);
|
||||
const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache,
|
||||
std::vector<std::shared_ptr<DSCallback>> callbacks) {
|
||||
auto ds = std::make_shared<MapNode>(input->IRNode(), operations, input_columns, output_columns, project_columns,
|
||||
cache, callbacks);
|
||||
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
|
|
@ -44,7 +44,12 @@ bool Iterator::GetNextRow(TensorVec *row) {
|
|||
}
|
||||
|
||||
// Shut down the data pipeline.
|
||||
void Iterator::Stop() { runtime_context_->Terminate(); }
|
||||
void Iterator::Stop() {
|
||||
Status rc = runtime_context_->Terminate();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << rc.ToString();
|
||||
}
|
||||
}
|
||||
|
||||
// Function to build and launch the execution tree.
|
||||
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
|
||||
|
|
|
@ -385,6 +385,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \return Status
|
||||
virtual Status WaitForWorkers() { return Status::OK(); }
|
||||
|
||||
/// \brief Add callback to DatasetOp, only MapOp supports Callback at the moment
|
||||
void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) { callback_manager_.AddCallbacks(callbacks); }
|
||||
|
||||
protected:
|
||||
/// \brief Removes a parent operator from this operator
|
||||
/// \notes External callers do not have access to this function
|
||||
|
|
|
@ -13,9 +13,9 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
|
@ -26,8 +26,6 @@
|
|||
#include "minddata/dataset/engine/data_buffer.h"
|
||||
#include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h"
|
||||
#include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h"
|
||||
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
|
@ -60,7 +58,7 @@ Status MapOp::Builder::Build(std::shared_ptr<MapOp> *ptr) {
|
|||
RETURN_IF_NOT_OK(sanityCheck());
|
||||
*ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_),
|
||||
std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_);
|
||||
(*ptr)->callback_manager_.AddCallbacks(std::move(builder_callbacks_));
|
||||
(*ptr)->AddCallbacks(std::move(builder_callbacks_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -31,13 +31,15 @@ namespace dataset {
|
|||
// constructor #1, called by Pybind
|
||||
BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad,
|
||||
const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
|
||||
py::function batch_size_func, py::function batch_map_func,
|
||||
const std::vector<std::string> &col_order, py::function batch_size_func,
|
||||
py::function batch_map_func,
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
|
||||
: batch_size_(batch_size),
|
||||
drop_remainder_(drop_remainder),
|
||||
pad_(pad),
|
||||
in_col_names_(in_col_names),
|
||||
out_col_names_(out_col_names),
|
||||
col_order_(col_order),
|
||||
batch_size_func_(batch_size_func),
|
||||
batch_map_func_(batch_map_func),
|
||||
pad_map_(pad_map) {
|
||||
|
@ -83,8 +85,8 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
|
|||
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_,
|
||||
pad_map_));
|
||||
// need to insert a project when per_batch_func changes the number of columns
|
||||
if (!out_col_names_.empty()) {
|
||||
auto project_op = std::make_shared<ProjectOp>(out_col_names_);
|
||||
if (!col_order_.empty()) {
|
||||
auto project_op = std::make_shared<ProjectOp>(col_order_);
|
||||
node_ops.push_back(project_op);
|
||||
}
|
||||
#else
|
||||
|
|
|
@ -34,7 +34,7 @@ class BatchNode : public DatasetNode {
|
|||
/// \brief Constructor #1, for Python API to create a BatchNode
|
||||
BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad,
|
||||
const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
|
||||
py::function batch_size_func, py::function batch_map_func,
|
||||
const std::vector<std::string> &col_order, py::function batch_size_func, py::function batch_map_func,
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map);
|
||||
#endif
|
||||
|
||||
|
@ -58,6 +58,7 @@ class BatchNode : public DatasetNode {
|
|||
bool pad_;
|
||||
std::vector<std::string> in_col_names_;
|
||||
std::vector<std::string> out_col_names_;
|
||||
std::vector<std::string> col_order_;
|
||||
#ifdef ENABLE_PYTHON
|
||||
py::function batch_size_func_;
|
||||
py::function batch_map_func_;
|
||||
|
|
|
@ -29,12 +29,14 @@ namespace dataset {
|
|||
|
||||
MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
|
||||
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
|
||||
const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache)
|
||||
const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache,
|
||||
std::vector<std::shared_ptr<DSCallback>> callbacks)
|
||||
: operations_(operations),
|
||||
input_columns_(input_columns),
|
||||
output_columns_(output_columns),
|
||||
project_columns_(project_columns),
|
||||
DatasetNode(std::move(cache)) {
|
||||
DatasetNode(std::move(cache)),
|
||||
callbacks_(callbacks) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
|
@ -53,6 +55,11 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
|
|||
// This parameter will be removed with next rebase
|
||||
std::vector<std::string> col_orders;
|
||||
auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_);
|
||||
|
||||
if (!callbacks_.empty()) {
|
||||
map_op->AddCallbacks(callbacks_);
|
||||
}
|
||||
|
||||
if (!project_columns_.empty()) {
|
||||
auto project_op = std::make_shared<ProjectOp>(project_columns_);
|
||||
node_ops.push_back(project_op);
|
||||
|
|
|
@ -31,7 +31,8 @@ class MapNode : public DatasetNode {
|
|||
/// \brief Constructor
|
||||
MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
|
||||
std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {},
|
||||
const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr);
|
||||
const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr,
|
||||
std::vector<std::shared_ptr<DSCallback>> callbacks = {});
|
||||
|
||||
/// \brief Destructor
|
||||
~MapNode() = default;
|
||||
|
@ -49,6 +50,7 @@ class MapNode : public DatasetNode {
|
|||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
std::vector<std::string> project_columns_;
|
||||
std::vector<std::shared_ptr<DSCallback>> callbacks_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -149,7 +149,7 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
|
|||
// We finish the walk of this RepeatOp's descendent nodes.
|
||||
// The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n.
|
||||
// But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode,
|
||||
// so we devide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp.
|
||||
// so we divide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp.
|
||||
num_repeats_ /= node->num_repeats();
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -120,7 +120,7 @@ Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<
|
|||
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
|
||||
}
|
||||
|
||||
// Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf
|
||||
// Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf
|
||||
// as save it for use by cache op in ascendant tree.
|
||||
if (is_caching_) {
|
||||
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
|
||||
|
@ -261,7 +261,8 @@ Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
|||
// Then, execute the transform for each pair
|
||||
for (auto cache_pair : cache_pass.cache_pairs()) {
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
|
||||
ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client());
|
||||
RETURN_IF_NOT_OK(
|
||||
ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()));
|
||||
}
|
||||
MS_LOG(INFO) << "Pre pass: Cache transform pass complete.";
|
||||
return Status::OK();
|
||||
|
|
|
@ -60,95 +60,95 @@ class CacheTransformPass : public TreePass {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
|
|
|
@ -276,9 +276,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
std::vector<std::string> input_columns = {},
|
||||
std::vector<std::string> output_columns = {},
|
||||
const std::vector<std::string> &project_columns = {},
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr,
|
||||
std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
|
||||
return std::make_shared<MapDataset>(shared_from_this(), operations, input_columns, output_columns, project_columns,
|
||||
cache);
|
||||
cache, callbacks);
|
||||
}
|
||||
|
||||
/// \brief Function to create a Project Dataset
|
||||
|
@ -443,7 +444,8 @@ class MapDataset : public Dataset {
|
|||
public:
|
||||
MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
|
||||
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
|
||||
const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache);
|
||||
const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache,
|
||||
std::vector<std::shared_ptr<DSCallback>> callbacks);
|
||||
};
|
||||
|
||||
class ProjectDataset : public Dataset {
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
#include "minddata/dataset/callback/ds_callback.h"
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/kernels/data/no_op.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
|
@ -149,7 +151,7 @@ TEST_F(MindDataTestCallback, TestBasicCallback) {
|
|||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
|
||||
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
|
||||
schema->AddColumn(col);
|
||||
ASSERT_OK(schema->AddColumn(col));
|
||||
std::shared_ptr<RandomDataOp> leaf;
|
||||
rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(44).Build(&leaf);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -196,7 +198,7 @@ TEST_F(MindDataTestCallback, TestMutiEpochCallback) {
|
|||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
|
||||
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
|
||||
schema->AddColumn(col);
|
||||
ASSERT_OK(schema->AddColumn(col));
|
||||
std::shared_ptr<RandomDataOp> leaf;
|
||||
rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -253,7 +255,7 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) {
|
|||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
|
||||
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
|
||||
schema->AddColumn(col);
|
||||
ASSERT_OK(schema->AddColumn(col));
|
||||
std::shared_ptr<RandomDataOp> leaf;
|
||||
rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
|
@ -296,3 +298,34 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) {
|
|||
EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
|
||||
EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestCallback, TestCAPICallback) {
|
||||
MS_LOG(INFO) << "Doing: MindDataTestCallback-TestCAPICallback";
|
||||
// config callback
|
||||
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64);
|
||||
std::shared_ptr<DSCallback> cb1 = tst_cb;
|
||||
// config leaf_op, use random_data to avoid I/O
|
||||
std::shared_ptr<SchemaObj> schema = std::make_shared<SchemaObj>();
|
||||
ASSERT_TRUE(schema->add_column("label", "uint32", {}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(44, schema);
|
||||
ds = ds->Map({transforms::TypeCast("uint64")}, {"label"}, {}, {}, nullptr, {cb1});
|
||||
ds = ds->Repeat(2);
|
||||
|
||||
TreeAdapter tree_adapter;
|
||||
// using tree_adapter to set num_epoch = 1
|
||||
ASSERT_OK(tree_adapter.Compile(ds->IRNode(), 1));
|
||||
|
||||
TensorRow row;
|
||||
ASSERT_OK(tree_adapter.GetNext(&row));
|
||||
while (!row.empty()) {
|
||||
ASSERT_OK(tree_adapter.GetNext(&row));
|
||||
}
|
||||
std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
|
||||
std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88};
|
||||
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1};
|
||||
// doing resize to make sure no unexpected epoch_end or extra epoch_begin is called
|
||||
size_t len = 7;
|
||||
EXPECT_EQ(tst_cb->all_names(len), callback_names);
|
||||
EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
|
||||
EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue