[MD] Add more UT and a ST for failover reset
commit 583eb32723 (parent e30446cca0); repository forked from mindspore-Ecosystem/mindspore

@@ -17,10 +17,12 @@
#include "minddata/dataset/engine/opt/pre/add_skip_pass.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@@ -73,7 +75,7 @@ Status AddSkipPass::InjectionFinder::VisitAfter(std::shared_ptr<TransferNode> no
 Status AddSkipPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
   RETURN_UNEXPECTED_IF_NULL(root_ir);
   RETURN_UNEXPECTED_IF_NULL(modified);
-  MS_LOG(INFO) << "Pre pass: Injection pass started.";
+  MS_LOG(INFO) << "Pre pass: AddSkipPass started.";
 
   // First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
   // The finder can make updates to the AddSkipPass object.
@@ -89,6 +91,15 @@ Status AddSkipPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const
   CHECK_FAIL_RETURN_UNEXPECTED(dataset_size > 0, "Cannot reset the pipeline, dataset size is undefined");
   int32_t num_epochs = finder.GetNumEpochs();
   int64_t step = finder.GetStep();
+  CHECK_FAIL_RETURN_UNEXPECTED(step >= 0,
+                               "Cannot reset the pipeline, reset step must be >= 0. step: " + std::to_string(step));
+  if (step >= dataset_size * num_epochs) {
+    std::string err_msg = "Cannot reset the pipeline, reset step must be less than dataset_size * num_epochs. step: " +
+                          std::to_string(step) + ", dataset_size: " + std::to_string(dataset_size) +
+                          ", num_epochs: " + std::to_string(num_epochs);
+    MS_LOG(ERROR) << err_msg;
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
   int32_t new_num_epochs = num_epochs - static_cast<int32_t>(step / dataset_size);
   int64_t skip_num = step % dataset_size;
 
@@ -98,7 +109,7 @@ Status AddSkipPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const
   skip_node->SetFirstEpochOnly(true);
   RETURN_IF_NOT_OK(node->InsertAbove(skip_node));
 
-  MS_LOG(INFO) << "Pre pass: Injection pass complete.";
+  MS_LOG(INFO) << "Pre pass: AddSkipPass complete.";
   return Status::OK();
 }
 }  // namespace dataset
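As an aside (not part of the diff), here is a minimal Python sketch of the resume arithmetic AddSkipPass performs: the reset step is split into fully completed epochs and the number of steps to skip inside the first resumed epoch. The concrete numbers are assumed purely for illustration; only the two formulas mirror the C++ above.

# Illustrative sketch of the AddSkipPass resume arithmetic; values are assumed.
dataset_size = 10  # steps per epoch (must be > 0, as the pass checks)
num_epochs = 3     # epochs originally requested
step = 25          # global step at which the pipeline is reset

# The pass rejects step < 0 and step >= dataset_size * num_epochs.
assert 0 <= step < dataset_size * num_epochs

new_num_epochs = num_epochs - step // dataset_size  # epochs still to run
skip_num = step % dataset_size                      # steps to skip in the first resumed epoch
print(new_num_epochs, skip_num)                     # -> 1 5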
@@ -0,0 +1,87 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+
+from mindspore.ops import operations as P
+import mindspore.nn as nn
+from mindspore.train import Model
+from mindspore.common import set_seed
+import mindspore.dataset as ds
+from mindspore.train.callback import Callback
+from mindspore import log as logger
+
+set_seed(1)
+
+
+def create_np_dataset(size):
+    """
+    Create dataset for train or test
+    """
+    data = ds.NumpySlicesDataset(list(range(1, size + 1)), shuffle=False)
+    return data
+
+
+def create_model():
+    """
+    Define and return a simple model
+    """
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.print = P.Print()
+
+        def construct(self, x):
+            self.print(x)
+            return x
+
+    net = Net()
+    model_ = Model(net)
+
+    return model_
+
+
+class MyCallback(Callback):
+    def __init__(self, dataset_size, reset_point):
+        self.dataset_size = dataset_size
+        self.reset_point = reset_point
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+        logger.info(f"Epoch #{cb_params.cur_epoch_num - 1} has ended")
+        if cb_params.cur_epoch_num == self.reset_point:
+            dataset = ds.engine.datasets._get_training_dataset()  # pylint: disable=W0212
+            dataset._reset(self.reset_point * self.dataset_size)  # pylint: disable=W0212
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dataset_reset_sink():
+    """
+    Feature: Dataset recovery
+    Description: Test Dataset recovery when GPU (and sink mode) is used.
+    Expectation: Training completes successfully
+    """
+    data = create_np_dataset(10)
+    model = create_model()
+    num_epochs = 3
+    reset_point = 2  # 2nd epoch
+    cb = MyCallback(dataset_size=data.get_dataset_size(), reset_point=reset_point)
+    model.train(num_epochs, data, callbacks=[cb])
+
+
+if __name__ == '__main__':
+    test_dataset_reset_sink()
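A note on the reset step used by MyCallback above (not part of the diff): _reset() takes a global step count, so resetting at the end of epoch reset_point passes reset_point * dataset_size, which is the first step of the following epoch. The values below are just the ones this ST uses.

# Reset step computed by MyCallback in the ST above (values taken from the test).
dataset_size = 10  # create_np_dataset(10)
reset_point = 2    # reset at the end of the 2nd epoch
reset_step = reset_point * dataset_size
print(reset_step)  # -> 20: training resumes from global step 20, the start of epoch 3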
@@ -13,19 +13,72 @@
 # limitations under the License.
 # ==============================================================================
 """
-Testing pipeline Reset
+Testing dataset pipeline failover Reset
 """
+import os
 import numpy as np
 import pytest
 import mindspore.dataset as ds
+import mindspore.dataset.vision.c_transforms as c_vision
+from util_minddataset import add_and_remove_cv_file
 
 np.random.seed(0)
 
 
 def create_np_dataset(size):
-    data = ds.NumpySlicesDataset(list(range(1, size + 1)), shuffle=False)
+    dimensions = (size, 4, 3, 2)
+    np_data = np.random.random(dimensions)
+    data = ds.NumpySlicesDataset(np_data, shuffle=False)
     return data
 
 
-def util(data, num_epochs, failure_point: int, reset_step):
+def create_cifar_dataset1(size):
+    data_dir = "../data/dataset/testCifar100Data"
+    pad_size = 100
+    crop_size = 64
+    data = ds.Cifar100Dataset(data_dir, num_samples=size, shuffle=False)
+    data = data.project(["image"])
+    pad_op = c_vision.Pad(pad_size)
+    data = data.map(operations=pad_op, input_columns=["image"])
+    crop_op = c_vision.CenterCrop(crop_size)
+    data = data.map(operations=crop_op, input_columns=["image"])
+    return data
+
+
+def create_cifar_dataset2(size):
+    data_dir = "../data/dataset/testCifar100Data"
+    pad_size = 100
+    crop_size = 64
+    repeat_count = 2
+    data = ds.Cifar100Dataset(data_dir, num_samples=size, shuffle=False)
+    data = data.repeat(repeat_count)
+    data = data.project(["image"])
+    pad_op = c_vision.Pad(pad_size)
+    data = data.map(operations=pad_op, input_columns=["image"])
+    crop_op = c_vision.CenterCrop(crop_size)
+    data = data.map(operations=crop_op, input_columns=["image"])
+    return data
+
+
+def create_imagenet_dataset(size):
+    data_dir = "../data/dataset/testImageNetData2/train"
+    batch_size = 2
+    data = ds.ImageFolderDataset(data_dir, num_samples=size * batch_size, shuffle=False)
+    data = data.batch(batch_size)
+    data = data.project(["image"])
+    return data
+
+
+def create_minddata_dataset(size):
+    columns_list = ["data"]
+    num_readers = 2
+    file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+    data = ds.MindDataset(file_name + "0", columns_list, num_readers, shuffle=False, num_samples=size)
+    data = data.rename(input_columns=["data"], output_columns="fake_data")
+    return data
+
+
+def run_reset(data, num_epochs, failure_point: int, reset_step: int):
     size = data.get_dataset_size()
     expected = []
     expected_itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
@@ -67,19 +120,113 @@ def util(data, num_epochs, failure_point: int, reset_step):
         np.testing.assert_array_equal(x, y)
 
 
-def test_reset():
+def run_reset_error(data, num_epochs: int, failure_point: int):
+    itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)  # pylint: disable=unused-variable
+    ds.engine.datasets._set_training_dataset(itr)  # pylint: disable=W0212
+
+    if failure_point > 0:
+        with pytest.raises(RuntimeError) as err:
+            ds.engine.datasets._reset_training_dataset(failure_point)  # pylint: disable=W0212
+        assert "Cannot reset the pipeline, reset step must be less than dataset_size * num_epochs." in str(err.value)
+    else:
+        with pytest.raises(RuntimeError) as err:
+            ds.engine.datasets._reset_training_dataset(failure_point)  # pylint: disable=W0212
+        assert "Cannot reset the pipeline, reset step must be >= 0." in str(err.value)
+
+
+def test_reset_np():
     """
     Feature: dataset recovery
-    Description: Simple test of data pipeline reset feature
+    Description: Simple test of data pipeline reset feature on a pipeline with NumpySlicesDataset as a leaf node
     Expectation: same datasets after reset
     """
-    dataset_size = 5
+    dataset_size = 50
     num_epochs = 3
+    failure_steps = (dataset_size * num_epochs) // 10
     data = create_np_dataset(size=dataset_size)
-    for failure_point in range(dataset_size * num_epochs):
-        for reset_step in range(dataset_size * num_epochs):
-            util(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
+    for failure_point in range(0, dataset_size * num_epochs, failure_steps):
+        for reset_step in range(0, dataset_size * num_epochs, failure_steps):
+            run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
+
+
+def test_reset_cifar1():
+    """
+    Feature: dataset recovery
+    Description: Simple test of data pipeline reset feature on a pipeline with Cifar100Dataset as a leaf node (1)
+    Expectation: same datasets after reset
+    """
+    dataset_size = 30
+    num_epochs = 2
+    failure_steps = (dataset_size * num_epochs) // 5
+    data = create_cifar_dataset1(size=dataset_size)
+    for failure_point in range(0, dataset_size * num_epochs, failure_steps):
+        for reset_step in range(0, dataset_size * num_epochs, failure_steps):
+            run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
+
+
+def test_reset_cifar2():
+    """
+    Feature: dataset recovery
+    Description: Simple test of data pipeline reset feature on a pipeline with Cifar100Dataset as a leaf node (2)
+    Expectation: same datasets after reset
+    """
+    dataset_size = 30
+    num_epochs = 3
+    failure_steps = (dataset_size * num_epochs) // 5
+    data = create_cifar_dataset2(size=dataset_size)
+    for failure_point in range(0, dataset_size * num_epochs, failure_steps):
+        for reset_step in range(0, dataset_size * num_epochs, failure_steps):
+            run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
+
+
+def test_reset_imagenet():
+    """
+    Feature: dataset recovery
+    Description: Simple test of data pipeline reset feature on a pipeline with ImageFolderDataset as a leaf node
+    Expectation: same datasets after reset
+    """
+    dataset_size = 3
+    num_epochs = 4
+    failure_steps = (dataset_size * num_epochs) // 4
+    data = create_imagenet_dataset(size=dataset_size)
+    for failure_point in range(0, dataset_size * num_epochs, failure_steps):
+        for reset_step in range(0, dataset_size * num_epochs, failure_steps):
+            run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
+
+
+def test_reset_mindrecord(add_and_remove_cv_file):  # pylint: disable=unused-argument, redefined-outer-name
+    """
+    Feature: dataset recovery
+    Description: Simple test of data pipeline reset feature on a pipeline with MindDataset as a leaf node
+    Expectation: same datasets after reset
+    """
+    dataset_size = 10
+    num_epochs = 3
+    failure_steps = (dataset_size * num_epochs) // 10
+    data = create_minddata_dataset(size=dataset_size)
+    for failure_point in range(0, dataset_size * num_epochs, failure_steps):
+        for reset_step in range(0, dataset_size * num_epochs, failure_steps):
+            run_reset(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
+
+
+def test_reset_np_error():
+    """
+    Feature: dataset recovery
+    Description: Simple test of data pipeline reset feature for error cases (step is negative, or larger than expected)
+    Expectation: failures are detected properly and correct error message is produced
+    """
+    dataset_size = 100
+    num_epochs = 3
+    failure_points = (-1000, -300, -99, -5, 300, 301, 1000)
+    data = create_np_dataset(size=dataset_size)
+    for failure_point in failure_points:
+        run_reset_error(data, num_epochs=num_epochs, failure_point=failure_point)
 
 
 if __name__ == "__main__":
-    test_reset()
+    test_reset_np()
+    test_reset_cifar1()
+    test_reset_cifar2()
+    test_reset_imagenet()
+    test_reset_mindrecord(add_and_remove_cv_file)
+    test_reset_np_error()
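For orientation (not part of the diff), a small worked example of the (failure_point, reset_step) grid these UTs sweep, using the values from test_reset_imagenet above; run_reset is called for every pair drawn from the resulting list.

# Grid swept by test_reset_imagenet (values taken from the test above).
dataset_size = 3
num_epochs = 4
failure_steps = (dataset_size * num_epochs) // 4  # -> 3
grid = list(range(0, dataset_size * num_epochs, failure_steps))
print(grid)  # -> [0, 3, 6, 9]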