forked from mindspore-Ecosystem/mindspore
Merge branch 'master' of gitee.com:mindspore/mindspore
This commit is contained in:
commit
e5fafc5e00
|
@ -16,6 +16,7 @@
|
|||
#include <iostream>
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/core/config_manager.h"
|
||||
#include "dataset/engine/data_buffer.h"
|
||||
#include "dataset/engine/datasetops/skip_op.h"
|
||||
#include "dataset/engine/db_connector.h"
|
||||
|
@ -26,7 +27,10 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Builder constructor. Creates the builder object.
|
||||
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {}
|
||||
SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_op_connector_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
||||
Status SkipOp::Builder::SanityCheck() const {
|
||||
if (build_max_skips_ < 0) {
|
||||
|
@ -39,12 +43,13 @@ Status SkipOp::Builder::SanityCheck() const {
|
|||
// The builder "build" method creates the final object.
|
||||
Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr = std::make_shared<SkipOp>(build_max_skips_);
|
||||
*ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Constructor of the SkipOp.
|
||||
SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}
|
||||
SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
|
||||
: PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {}
|
||||
|
||||
// Destructor
|
||||
SkipOp::~SkipOp() {}
|
||||
|
@ -59,49 +64,6 @@ void SkipOp::Print(std::ostream &out, bool show_all) const {
|
|||
<< "\nCurrent skip count: " << skip_count_ << "\nMax skip count: " << max_skips_;
|
||||
}
|
||||
|
||||
// Since the buffer may contain multi rows, this function will drop the rows
|
||||
// that need to skip in it, and then return the buffer.
|
||||
Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
|
||||
if (child_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> buf;
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
||||
|
||||
// Drop first max_skips_ rows
|
||||
while (skip_count_ < max_skips_) {
|
||||
if (buf->eoe() || buf->eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Consider the rows of buffer more than 1
|
||||
TensorRow drop_row;
|
||||
int row_num = buf->NumRows();
|
||||
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
|
||||
skip_count_ += drop_num;
|
||||
for (int i = 0; i < drop_num; i++) {
|
||||
RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
|
||||
}
|
||||
if (buf->NumRows() == 0) {
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
|
||||
}
|
||||
}
|
||||
|
||||
// Handling eoe
|
||||
if (buf->eoe()) {
|
||||
RETURN_IF_NOT_OK(EoeReceived(worker_id));
|
||||
}
|
||||
|
||||
// Handling eof
|
||||
if (buf->eof()) {
|
||||
RETURN_IF_NOT_OK(EofReceived(worker_id));
|
||||
}
|
||||
|
||||
*p_buffer = std::move(buf);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Base-class override for handling cases when an eoe is received.
|
||||
Status SkipOp::EoeReceived(int32_t worker_id) {
|
||||
skip_count_ = 0;
|
||||
|
@ -109,13 +71,45 @@ Status SkipOp::EoeReceived(int32_t worker_id) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Class functor operator () override.
|
||||
// Most dataset ops operate by launching a thread (see ExecutionTree).
|
||||
// However, the SkipOp is defined as a inlined operator, so it is invalid to
|
||||
// launch the functor since this op runs inlined inside another operator. The
|
||||
// function is overloaded to ensure that it is not called by mistake (it will
|
||||
// generate an error).
|
||||
Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }
|
||||
// main entry point for skip
|
||||
Status SkipOp::operator()() {
|
||||
TaskManager::FindMe()->Post();
|
||||
std::unique_ptr<DataBuffer> curr_buffer;
|
||||
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
||||
while (curr_buffer->eof() == false) {
|
||||
// Reset count
|
||||
skip_count_ = 0;
|
||||
while (curr_buffer->eoe() == false) {
|
||||
// Drop first count rows
|
||||
while (skip_count_ < max_skips_) {
|
||||
if (curr_buffer->eoe() || curr_buffer->eof()) {
|
||||
break;
|
||||
}
|
||||
// Consider the rows of buffer more than one
|
||||
TensorRow drop_row;
|
||||
int row_num = curr_buffer->NumRows();
|
||||
int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
|
||||
skip_count_ += drop_num;
|
||||
for (int i = 0; i < drop_num; i++) {
|
||||
RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row));
|
||||
}
|
||||
if (curr_buffer->NumRows() == 0) {
|
||||
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
|
||||
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
||||
}
|
||||
// we got eoe, now try again until we got eof
|
||||
MS_LOG(DEBUG) << "Skip operator EOE Received.";
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
|
||||
RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Skip operator EOF Received.";
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Base-class override for handling cases when an eof is received.
|
||||
Status SkipOp::EofReceived(int32_t worker_id) {
|
||||
|
|
|
@ -42,6 +42,7 @@ class SkipOp : public PipelineOp {
|
|||
|
||||
private:
|
||||
int32_t build_max_skips_;
|
||||
int32_t builder_op_connector_size_;
|
||||
|
||||
Status SanityCheck() const;
|
||||
};
|
||||
|
@ -49,7 +50,7 @@ class SkipOp : public PipelineOp {
|
|||
// Constructor of the SkipOp.
|
||||
// @note The builder class should be used to call it
|
||||
// @param count - The number of skips to do
|
||||
explicit SkipOp(int32_t count);
|
||||
explicit SkipOp(int32_t count, int32_t op_connector_size);
|
||||
|
||||
// Destructor
|
||||
~SkipOp();
|
||||
|
@ -60,23 +61,11 @@ class SkipOp : public PipelineOp {
|
|||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
// Class functor operator () override.
|
||||
// Most dataset ops operate by launching a thread (see ExecutionTree).
|
||||
// However, the SkipOp is defined as a inlined operator, so it is invalid to launch the
|
||||
// functor since this op runs inlined inside another operator. The function is overloaded to
|
||||
// ensure that it is not called by mistake (it will generate an error).
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - The error code return
|
||||
Status operator()() override;
|
||||
|
||||
// This function returns the buffer that is at the top of our output connector. The caller is
|
||||
// typically our parent node, when the parent is asking us to provide the next buffer of data.
|
||||
// Since SkipOp is an inlined op, getting a buffer from us will simply bounce you to get
|
||||
// a buffer from our child.
|
||||
// @param p_buffer - output pointer to the buffer that it will fetch.
|
||||
// @param worker_id - The worker id
|
||||
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
|
||||
// @return Status - The error code return
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
|
||||
|
||||
// Base-class override for handling cases when an eoe is received.
|
||||
// @param worker_id - The worker id
|
||||
Status EoeReceived(int32_t worker_id) override;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
from abc import abstractmethod
|
||||
import copy
|
||||
import weakref
|
||||
from importlib import import_module
|
||||
|
||||
from mindspore._c_dataengine import DEPipeline
|
||||
from mindspore._c_dataengine import OpName
|
||||
|
@ -24,14 +25,29 @@ from mindspore._c_dataengine import OpName
|
|||
from mindspore import log as logger
|
||||
from . import datasets as de
|
||||
|
||||
try:
|
||||
context = import_module("mindspore.context")
|
||||
except ModuleNotFoundError:
|
||||
context = None
|
||||
|
||||
ITERATORS_LIST = list()
|
||||
|
||||
|
||||
def _cleanup():
|
||||
"""Release all the Iterator."""
|
||||
for itr_ref in ITERATORS_LIST:
|
||||
itr = itr_ref()
|
||||
if itr is not None:
|
||||
itr.release()
|
||||
if context:
|
||||
device_type = context.get_context("device_target")
|
||||
if device_type == "GPU":
|
||||
itr_ref.release()
|
||||
else:
|
||||
itr = itr_ref()
|
||||
if itr is not None:
|
||||
itr.release()
|
||||
else:
|
||||
itr = itr_ref()
|
||||
if itr is not None:
|
||||
itr.release()
|
||||
|
||||
|
||||
def alter_tree(node):
|
||||
|
@ -85,7 +101,14 @@ class Iterator:
|
|||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
if context:
|
||||
device_type = context.get_context("device_target")
|
||||
if device_type == "GPU":
|
||||
ITERATORS_LIST.append(self)
|
||||
else:
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
else:
|
||||
ITERATORS_LIST.append(weakref.ref(self))
|
||||
# create a copy of tree and work on it.
|
||||
self.dataset = copy.deepcopy(dataset)
|
||||
self.dataset = alter_tree(self.dataset)
|
||||
|
|
|
@ -76,8 +76,13 @@ class BoundingBoxEncode(PrimitiveWithInfer):
|
|||
Tensor, encoded bounding boxes.
|
||||
|
||||
Examples:
|
||||
>>> anchor_box = Tensor([[4,1,2,1],[2,2,2,3]],mindspore.float32)
|
||||
>>> groundtruth_box = Tensor([[3,1,2,2],[1,2,1,4]],mindspore.float32)
|
||||
>>> boundingbox_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0))
|
||||
>>> delta_box = boundingbox_encode(anchor_box, groundtruth_box)
|
||||
>>> boundingbox_encode(anchor_box, groundtruth_box)
|
||||
[[5.0000000e-01 5.0000000e-01 -6.5504000e+04 6.9335938e-01]
|
||||
[-1.0000000e+00 2.5000000e-01 0.0000000e+00 4.0551758e-01]]
|
||||
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -118,9 +123,14 @@ class BoundingBoxDecode(PrimitiveWithInfer):
|
|||
Tensor, decoded boxes.
|
||||
|
||||
Examples:
|
||||
>>> anchor_box = Tensor([[4,1,2,1],[2,2,2,3]],mindspore.float32)
|
||||
>>> deltas = Tensor([[3,1,2,2],[1,2,1,4]],mindspore.float32)
|
||||
>>> boundingbox_decode = P.BoundingBoxDecode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0),
|
||||
>>> max_shape=(768, 1280), wh_ratio_clip=0.016)
|
||||
>>> bbox = boundingbox_decode(anchor_box, deltas)
|
||||
>>> boundingbox_decode(anchor_box, deltas)
|
||||
[[4.1953125 0. 0. 5.1953125]
|
||||
[2.140625 0. 3.859375 60.59375]]
|
||||
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.train.model import Model
|
||||
from mindspore import context
|
||||
import mindspore.common.dtype as mstype
|
||||
import os
|
||||
import numpy as np
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
from resnet import resnet50
|
||||
import random
|
||||
import time
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
ds.config.set_seed(1)
|
||||
|
||||
data_home = "/home/workspace/mindspore_dataset"
|
||||
|
||||
|
||||
def create_dataset(repeat_num=1, training=True, batch_size=32):
|
||||
data_dir = data_home + "/cifar-10-batches-bin"
|
||||
if not training:
|
||||
data_dir = data_home + "/cifar-10-verify-bin"
|
||||
data_set = ds.Cifar10Dataset(data_dir)
|
||||
|
||||
resize_height = 224
|
||||
resize_width = 224
|
||||
rescale = 1.0 / 255.0
|
||||
shift = 0.0
|
||||
|
||||
# define map operations
|
||||
random_crop_op = vision.RandomCrop(
|
||||
(32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
|
||||
random_horizontal_op = vision.RandomHorizontalFlip()
|
||||
# interpolation default BILINEAR
|
||||
resize_op = vision.Resize((resize_height, resize_width))
|
||||
rescale_op = vision.Rescale(rescale, shift)
|
||||
normalize_op = vision.Normalize(
|
||||
(0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
|
||||
changeswap_op = vision.HWC2CHW()
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
|
||||
c_trans = []
|
||||
if training:
|
||||
c_trans = [random_crop_op, random_horizontal_op]
|
||||
c_trans += [resize_op, rescale_op, normalize_op,
|
||||
changeswap_op]
|
||||
|
||||
# apply map operations on images
|
||||
data_set = data_set.map(input_columns="label", operations=type_cast_op)
|
||||
data_set = data_set.map(input_columns="image", operations=c_trans)
|
||||
|
||||
# apply shuffle operations
|
||||
data_set = data_set.shuffle(buffer_size=1000)
|
||||
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
|
||||
|
||||
# apply repeat operations
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
||||
return data_set
|
||||
|
||||
|
||||
class CrossEntropyLoss(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean()
|
||||
self.one_hot = P.OneHot()
|
||||
self.one = Tensor(1.0, mstype.float32)
|
||||
self.zero = Tensor(0.0, mstype.float32)
|
||||
|
||||
def construct(self, logits, label):
|
||||
label = self.one_hot(label, F.shape(logits)[1], self.one, self.zero)
|
||||
loss = self.cross_entropy(logits, label)[0]
|
||||
loss = self.mean(loss, (-1,))
|
||||
return loss
|
||||
|
||||
|
||||
class LossGet(Callback):
|
||||
def __init__(self, per_print_times=1):
|
||||
super(LossGet, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self._loss = 0.0
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
loss = cb_params.net_outputs
|
||||
|
||||
if isinstance(loss, (tuple, list)):
|
||||
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
|
||||
loss = loss[0]
|
||||
|
||||
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
|
||||
loss = np.mean(loss.asnumpy())
|
||||
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
|
||||
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
|
||||
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
|
||||
.format(cb_params.cur_epoch_num, cur_step_in_epoch))
|
||||
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
|
||||
self._loss = loss
|
||||
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss))
|
||||
|
||||
def get_loss(self):
|
||||
return self._loss
|
||||
|
||||
|
||||
def train_process(device_id, epoch_size, num_classes, device_num, batch_size):
|
||||
os.system("mkdir " + str(device_id))
|
||||
os.chdir(str(device_id))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(enable_task_sink=True, device_id=device_id)
|
||||
context.set_context(enable_loop_sink=True)
|
||||
context.set_context(enable_mem_reuse=True)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = resnet50(batch_size, num_classes)
|
||||
loss = CrossEntropyLoss()
|
||||
opt = Momentum(filter(lambda x: x.requires_grad,
|
||||
net.get_parameters()), 0.01, 0.9)
|
||||
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
|
||||
dataset = create_dataset(epoch_size, training=True, batch_size=batch_size)
|
||||
batch_num = dataset.get_dataset_size()
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=1)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10_device_id_" + str(device_id), directory="./",
|
||||
config=config_ck)
|
||||
loss_cb = LossGet()
|
||||
model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb])
|
||||
|
||||
|
||||
def eval(batch_size, num_classes):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(enable_task_sink=True, device_id=0)
|
||||
context.set_context(enable_loop_sink=True)
|
||||
context.set_context(enable_mem_reuse=True)
|
||||
|
||||
net = resnet50(batch_size, num_classes)
|
||||
loss = CrossEntropyLoss()
|
||||
opt = Momentum(filter(lambda x: x.requires_grad,
|
||||
net.get_parameters()), 0.01, 0.9)
|
||||
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
checkpoint_path = "./train_resnet_cifar10_device_id_0-1_1562.ckpt"
|
||||
param_dict = load_checkpoint(checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
eval_dataset = create_dataset(1, training=False)
|
||||
res = model.eval(eval_dataset)
|
||||
print("result: ", res)
|
||||
return res
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_resnet_cifar_1p():
|
||||
device_num = 1
|
||||
epoch_size = 1
|
||||
num_classes = 10
|
||||
batch_size = 32
|
||||
device_id = 0
|
||||
train_process(device_id, epoch_size, num_classes, device_num, batch_size)
|
||||
time.sleep(3)
|
||||
acc = eval(batch_size, num_classes)
|
||||
os.chdir("../")
|
||||
os.system("rm -rf " + str(device_id))
|
||||
print("End training...")
|
||||
assert (acc['acc'] > 0.35)
|
|
@ -47,7 +47,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
|
|||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
// SkipOp
|
||||
std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5);
|
||||
std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5, 2);
|
||||
rc = my_tree->AssociateNode(skip_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
|
@ -51,7 +50,7 @@ def generator_md():
|
|||
|
||||
|
||||
def test_generator_skip():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"], num_parallel_workers=4)
|
||||
|
||||
# Here ds1 should be [3, 4]
|
||||
ds1 = ds1.skip(3)
|
||||
|
@ -60,6 +59,7 @@ def test_generator_skip():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 2
|
||||
assert buf == [3, 4]
|
||||
|
||||
|
||||
def test_skip_1():
|
||||
|
@ -72,6 +72,7 @@ def test_skip_1():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 0
|
||||
assert buf == []
|
||||
|
||||
|
||||
def test_skip_2():
|
||||
|
@ -84,6 +85,7 @@ def test_skip_2():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 5
|
||||
assert buf == [0, 1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_skip_repeat_1():
|
||||
|
@ -99,6 +101,7 @@ def test_skip_repeat_1():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 7
|
||||
assert buf == [3, 4, 0, 1, 2, 3, 4]
|
||||
|
||||
|
||||
def test_skip_repeat_2():
|
||||
|
@ -114,6 +117,7 @@ def test_skip_repeat_2():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 4
|
||||
assert buf == [3, 4, 3, 4]
|
||||
|
||||
|
||||
def test_skip_repeat_3():
|
||||
|
@ -132,6 +136,62 @@ def test_skip_repeat_3():
|
|||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 6
|
||||
assert buf == [3, 4, 3, 4, 3, 4]
|
||||
|
||||
def test_skip_take_1():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
# Here ds1 should be [0, 1, 2, 3]
|
||||
ds1 = ds1.take(4)
|
||||
|
||||
# Here ds1 should be [2, 3]
|
||||
ds1 = ds1.skip(2)
|
||||
|
||||
buf = []
|
||||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 2
|
||||
assert buf == [2, 3]
|
||||
|
||||
def test_skip_take_2():
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
# Here ds1 should be [2, 3, 4]
|
||||
ds1 = ds1.skip(2)
|
||||
|
||||
# Here ds1 should be [2, 3]
|
||||
ds1 = ds1.take(2)
|
||||
|
||||
buf = []
|
||||
for data in ds1:
|
||||
buf.append(data[0][0])
|
||||
assert len(buf) == 2
|
||||
assert buf == [2, 3]
|
||||
|
||||
|
||||
def generator_1d():
|
||||
for i in range(64):
|
||||
yield (np.array([i]), )
|
||||
|
||||
def test_skip_filter_1():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ['data'])
|
||||
dataset = dataset.skip(5)
|
||||
dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
|
||||
|
||||
buf = []
|
||||
for item in dataset:
|
||||
buf.append(item[0][0])
|
||||
assert buf == [5, 6, 7, 8, 9, 10]
|
||||
|
||||
def test_skip_filter_2():
|
||||
dataset = ds.GeneratorDataset(generator_1d, ['data'])
|
||||
dataset = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
|
||||
dataset = dataset.skip(5)
|
||||
|
||||
buf = []
|
||||
for item in dataset:
|
||||
buf.append(item[0][0])
|
||||
assert buf == [5, 6, 7, 8, 9, 10]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -142,3 +202,7 @@ if __name__ == "__main__":
|
|||
test_skip_repeat_1()
|
||||
test_skip_repeat_2()
|
||||
test_skip_repeat_3()
|
||||
test_skip_take_1()
|
||||
test_skip_take_2()
|
||||
test_skip_filter_1()
|
||||
test_skip_filter_2()
|
||||
|
|
Loading…
Reference in New Issue