!49142 [MD] Add new API to debug mode and bugfix

Merge pull request !49142 from TinaMengtingZhang/dev_md_pipeline_debug_ops
This commit is contained in:
i-robot 2023-03-02 01:50:43 +00:00 committed by Gitee
commit cee9da14ae
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
13 changed files with 392 additions and 63 deletions

View File

@ -21,6 +21,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
@ -75,5 +76,21 @@ Status ShuffleNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNo
*result = std::make_shared<ShuffleNode>(ds, buffer_size, reset_every_epoch);
return Status::OK();
}
// Visitor accepting method for IRNodePass
Status ShuffleNode::Accept(IRNodePass *const p, bool *const modified) {
RETURN_UNEXPECTED_IF_NULL(p);
RETURN_UNEXPECTED_IF_NULL(modified);
// Downcast shared pointer then call visitor
return p->Visit(shared_from_base<ShuffleNode>(), modified);
}
// Visitor accepting method for IRNodePass
Status ShuffleNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
RETURN_UNEXPECTED_IF_NULL(p);
RETURN_UNEXPECTED_IF_NULL(modified);
// Downcast shared pointer then call visitor
return p->VisitAfter(shared_from_base<ShuffleNode>(), modified);
}
} // namespace dataset
} // namespace mindspore
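For context on the two new methods: Accept and AcceptAfter implement the double dispatch that lets an IRNodePass run type-specific logic per node (here, so the new debug-mode pass can match ShuffleNode). A minimal Python sketch of the idea, with simplified stand-in classes rather than the actual MindData C++ types:

class IRNodePass:
    def visit(self, node, modified):
        # Generic fallback; concrete passes special-case node types.
        return "OK"

class DebugPass(IRNodePass):
    def visit(self, node, modified):
        if isinstance(node, ShuffleNode):
            node.shuffle_seed = 1  # mirrors the fixed internal seed in this change
            modified[0] = True
        return "OK"

class ShuffleNode:
    def __init__(self, shuffle_seed):
        self.shuffle_seed = shuffle_seed

    def accept(self, p, modified):
        # The node hands its concrete type back to the pass, which then
        # dispatches on it; this is what Accept/AcceptAfter do in C++.
        return p.visit(self, modified)

node, changed = ShuffleNode(shuffle_seed=42), [False]
node.accept(DebugPass(), changed)  # changed[0] is now True; seed pinned to 1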

View File

@ -57,6 +57,10 @@ class ShuffleNode : public DatasetNode {
uint32_t ShuffleSeed() const { return shuffle_seed_; }
bool ResetEveryEpoch() const { return reset_every_epoch_; }
/// \brief Setter function for shuffle_seed_
/// \param[in] shuffle_seed The shuffle seed value to be set
void SetShuffleSeed(uint32_t shuffle_seed) { shuffle_seed_ = shuffle_seed; }
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
@ -70,6 +74,18 @@ class ShuffleNode : public DatasetNode {
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The IRNodePass visitor to accept
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(IRNodePass *const p, bool *const modified) override;
/// \brief Base-class override for accepting IRNodePass visitor
/// \param[in] p The IRNodePass visitor to accept
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
private:
int32_t shuffle_size_;
uint32_t shuffle_seed_;

View File

@ -19,10 +19,12 @@
#include <string>
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/include/dataset/datasets.h"
namespace mindspore {
namespace dataset {
constexpr uint32_t kSeedValue = 1;
bool DebugModePass::DebugPass::RemoveCache(std::shared_ptr<DatasetNode> node) const {
// remove DatasetNode cache
bool ret = false;
@ -51,6 +53,19 @@ Status DebugModePass::DebugPass::Visit(std::shared_ptr<MapNode> node, bool *cons
return Status::OK();
}
Status DebugModePass::DebugPass::Visit(std::shared_ptr<ShuffleNode> node, bool *const modified) {
// Debug mode requires deterministic results. Replace the shuffle_seed of the Shuffle node with a fixed internal
// seed if the user did not set a global config seed. (The internal seed has not been written to the global config
// at this point, so GlobalContext still returns the user-configured value.)
uint32_t seed = GlobalContext::config_manager()->seed();
if (seed == std::mt19937::default_seed) {
MS_LOG(INFO) << "Replace shuffle_seed of Shuffle node with internal seed: " << std::to_string(kSeedValue) << ".";
(void)node->SetShuffleSeed(kSeedValue);
*modified = true;
}
return Status::OK();
}
Status DebugModePass::DebugPass::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
*modified = RemoveCache(node);
return Status::OK();
@ -67,7 +82,6 @@ Status DebugModePass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *cons
// Debug mode requires deterministic results. Set the seed if the user has not done so.
uint32_t seed = GlobalContext::config_manager()->seed();
if (seed == std::mt19937::default_seed) {
int8_t kSeedValue = 1;
MS_LOG(WARNING) << "Debug mode is enabled. Set seed to ensure deterministic results. Seed value: "
<< std::to_string(kSeedValue) << ".";
GlobalContext::config_manager()->set_seed(kSeedValue);
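The net effect, seen from Python: with debug mode on and no user seed, the pre-pass pins the Shuffle seed, so rebuilding the same pipeline yields the same order. A small sketch with toy data, assuming the debug-mode configuration added in this change:

import mindspore.dataset as ds

def run_pipeline():
    data = ds.NumpySlicesDataset(list(range(8)), column_names=["x"], shuffle=False)
    data = data.shuffle(buffer_size=4)
    return [int(d["x"]) for d in data.create_dict_iterator(num_epochs=1, output_numpy=True)]

ds.config.set_debug_mode(True)
# Both runs use the internal seed (kSeedValue = 1) injected by the IR pre-pass.
assert run_pipeline() == run_pipeline()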

View File

@ -52,6 +52,12 @@ class DebugModePass : public IRTreePass {
/// \return Status code
Status Visit(std::shared_ptr<MapNode> node, bool *const modified) override;
/// \brief Runs a pass on ShuffleNode
/// \param[in] node The node being visited
/// \param[in, out] *modified indicates if the node was changed at all
/// \return Status code
Status Visit(std::shared_ptr<ShuffleNode> node, bool *const modified) override;
/// \brief Runs a pass on DatasetNode
/// \param[in] node The node being visited
/// \param[in, out] *modified indicates if the node was changed at all

View File

@ -31,6 +31,7 @@ import numpy
import mindspore._c_dataengine as cde
from mindspore import log as logger
from mindspore.dataset.core.validator_helpers import replace_none, type_check
from mindspore.dataset.debug import DebugHook, PrintMetaDataHook
__all__ = ['set_sending_batches', 'load', '_init_device_info',
'set_seed', 'get_seed',
@ -54,6 +55,7 @@ INT32_MAX = 2147483647
UINT32_MAX = 4294967295
_config = cde.GlobalContext.config_manager()
_debug_context = {}
def _init_device_info():
@ -840,7 +842,7 @@ def get_fast_recovery():
return _config.get_fast_recovery()
def set_debug_mode(debug_mode_flag):
def set_debug_mode(debug_mode_flag: bool, debug_hook_list: list = None):
"""
Set the debug_mode flag of the dataset pipeline. When enabled, the dataset pipeline is run synchronously and
sequentially with a single thread.
@ -866,18 +868,60 @@ def set_debug_mode(debug_mode_flag):
Args:
debug_mode_flag (bool): Whether dataset pipeline debug mode is enabled, which forces the pipeline
to run synchronously and sequentially.
debug_hook_list (list[DebugHook]): A list of debug hook objects to be inserted before and after each
transform of the map operation. Default: None, which means `[PrintMetaDataHook]` is used,
printing the shape/size/type of the input/output data of each transformation.
Raises:
TypeError: If `debug_mode_flag` is not a boolean data type.
TypeError: If `debug_hook_list` is not a list type.
TypeError: If any item in `debug_hook_list` is not DebugHook type.
Examples:
1. Enable dataset pipeline debug mode and use the default debug hook.
>>> import mindspore.dataset as ds
>>> # Print shape and type of input/output data of each transform op in map operator.
>>> ds.config.set_debug_mode(True)
2. Enable dataset pipeline debug mode and use a pre-defined debug hook provided by MindData.
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.debug as debug
>>> ds.config.set_debug_mode(True, debug_hook_list=[debug.PrintDataHook()])
3. Enable dataset pipeline debug mode and use a user-defined debug hook. The hook must be a
class that inherits from DebugHook.
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.debug as debug
>>> class CustomizedHook(debug.DebugHook):
...     def __init__(self):
...         super().__init__()
...     def compute(self, *args):
...         # Add your debugging code here.
...         return args
>>> ds.config.set_debug_mode(True, debug_hook_list=[CustomizedHook()])
4. Enable dataset pipeline debug mode and manually insert a user-defined debug hook
(CustomizedHook as defined in example 3 above).
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.vision as vision
>>> ds.config.set_debug_mode(True)
>>> dataset = ds.ImageFolderDataset(...)
>>> # The debug hook is inserted after the `Decode` operation.
>>> dataset = dataset.map([vision.Decode(), CustomizedHook(), vision.CenterCrop(32)])
"""
if not isinstance(debug_mode_flag, bool):
raise TypeError("debug_mode_flag isn't of type boolean.")
if not debug_hook_list:
debug_hook_list = [PrintMetaDataHook()]
if not isinstance(debug_hook_list, list):
raise TypeError("debug_hook_list is not a list.")
for debug_func in debug_hook_list:
if not isinstance(debug_func, DebugHook):
raise TypeError("All items in debug_hook_list must be of type DebugHook.")
if debug_mode_flag:
logger.warning("Dataset pipeline debug mode is enabled. Performance will be impacted because the pipeline"
" will be running in a single thread.")
if debug_hook_list:
_debug_context["debug_hook_list"] = debug_hook_list
_config.set_debug_mode(debug_mode_flag)
@ -894,6 +938,17 @@ def get_debug_mode():
return _config.get_debug_mode()
def _get_debug_hook_list():
"""
INTERNAL USE ONLY!
Get value of debug_hook_list.
Returns:
list, the debug hook objects to be inserted into map operations to debug the inputs/outputs of each transform.
"""
return _debug_context.get("debug_hook_list")
class ErrorSamplesMode(IntEnum):
"""
An enumeration for `error_samples_mode` .

View File

@ -0,0 +1,21 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Init file for dataset pipeline debug mode"""
from __future__ import absolute_import
from mindspore.dataset.debug.debug_hook import DebugHook
from mindspore.dataset.debug.pre_defined_hook import PrintMetaDataHook, PrintDataHook
__all__ = ["DebugHook", "PrintMetaDataHook", "PrintDataHook"]

View File

@ -14,18 +14,17 @@
# ==============================================================================
"""
This module defines the class for minddata pipeline debugger.
class DebugWrapper is not exposed to users as an external API.
class DebugHook is not exposed to users as an external API.
"""
import collections
import numpy as np
from PIL import Image
from abc import ABC, abstractmethod
from mindspore import log as logger
class DebugWrapper:
class DebugHook(ABC):
"""
A class for Minddata Python Debugger.
The base class for Dataset Pipeline Python Debugger hooks. All user-defined hooks
must inherit from this base class.
To debug the input and output data of the map operation in a dataset pipeline, users can add
breakpoints to, or single-step through, this class. They can also see the type and shape of
@ -37,27 +36,27 @@ class DebugWrapper:
def __init__(self, prev_op_name=None):
self.prev_op_name = prev_op_name
def __call__(self, x):
def __call__(self, *args):
# log op name
if self.prev_op_name:
log_message = "Debugging the output of the operation [{}].".format(self.prev_op_name)
else:
log_message = "Debugging the input of the first operation."
# log type
log_message += " The type is [{}].".format(type(x))
# log shape/size
if isinstance(x, np.ndarray):
log_message += " The shape is [{}].".format(x.shape)
elif isinstance(x, Image.Image):
log_message += " The shape is [{}].".format(x.size)
elif isinstance(x, collections.abc.Sized):
log_message += " The size is [{}].".format(len(x))
logger.info(log_message)
######################## NOTE ########################
# Add a breakpoint to the following line to inspect
# input and output of each transform.
######################################################
logger.info(log_message)
return x
self.compute(args)
return args
@abstractmethod
def compute(self, *args):
"""
Defines the debug behaviour to be performed. This method must be overridden by all subclasses.
"""
raise RuntimeError("compute() is not overridden in subclass of class DebugHook.")
def set_previous_op_name(self, prev_op_name):
self.prev_op_name = prev_op_name
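To illustrate the contract: a subclass only overrides compute(); __call__ logs the position in the pipeline and forwards the data. A hypothetical NaN-checking hook as a sketch (note that __call__ above passes the whole column tuple to compute() as a single argument):

import numpy as np
from mindspore.dataset.debug import DebugHook

class NanCheckHook(DebugHook):
    def compute(self, *args):
        (cols,) = args  # __call__ wraps the column tuple in one argument
        for col in cols:
            if isinstance(col, np.ndarray) and np.isnan(col).any():
                raise ValueError("NaN detected after [{}]".format(self.prev_op_name))
        return args

Registered via ds.config.set_debug_mode(True, debug_hook_list=[NanCheckHook()]), it then runs before and after every transform of each map operation.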

View File

@ -0,0 +1,67 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module defines the subclasses of DebugHook for the minddata pipeline debugger.
These classes are pre-defined for users for basic debugging purposes.
"""
import collections
import numpy as np
from PIL import Image
from mindspore import log as logger
from mindspore.dataset.debug.debug_hook import DebugHook
class PrintMetaDataHook(DebugHook):
"""
Debug hook used for MindData debug mode to print type and shape of data.
"""
def __init__(self):
super().__init__()
def compute(self, *args):
for col_idx, col in enumerate(*args):
log_message = "Column {}. ".format(col_idx)
# log type
log_message += "The type is [{}].".format(type(col))
# log shape/size
if isinstance(col, np.ndarray):
log_message += " The shape is [{}].".format(col.shape)
elif isinstance(col, Image.Image):
log_message += " The shape is [{}].".format(col.size)
elif isinstance(col, collections.abc.Sized):
log_message += " The size is [{}].".format(len(col))
logger.info(log_message)
return args
class PrintDataHook(DebugHook):
"""
Debug hook used for MindData debug mode to print data.
"""
def __init__(self):
super().__init__()
def compute(self, *args):
for col_idx, col in enumerate(*args):
log_message = "Column {}. ".format(col_idx)
if isinstance(col, Image.Image):
data = np.asarray(col)
log_message += "The data is [{}].".format(data)
else:
log_message += "The data is [{}].".format(col)
logger.info(log_message)
return args
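These pre-defined hooks can also be exercised directly, outside a pipeline, which is handy when developing a custom hook. A usage sketch (the messages appear only if the MindSpore logger is set to INFO level):

import numpy as np
from mindspore.dataset.debug import PrintMetaDataHook

hook = PrintMetaDataHook()
hook.set_previous_op_name("Decode")
# Logs the column index, type, and shape of each column passed in.
hook(np.zeros((224, 224, 3), dtype=np.uint8))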

View File

@ -59,7 +59,7 @@ import mindspore.dataset.transforms.py_transforms as py_transforms
import mindspore.dataset.transforms as transforms
from mindspore.dataset.text.utils import SentencePieceModel, DE_C_INTER_SENTENCEPIECE_MODE
from mindspore.parallel._utils import _get_device_num
from mindspore.dataset.engine.debug import DebugWrapper
from mindspore.dataset.debug import DebugHook
from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator, check_iterator_cleanup, _set_iterator_cleanup, \
@ -70,7 +70,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \
check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_padded_batch
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_enable_watchdog, get_seed, set_seed, get_debug_mode, get_multiprocessing_timeout_interval
get_enable_watchdog, get_seed, set_seed, get_debug_mode, get_multiprocessing_timeout_interval, _get_debug_hook_list
from ..core.datatypes import mstype_to_detype
from ..core.validator_helpers import replace_none
from ..core.py_util_helpers import ExceptionHandler
@ -3336,6 +3336,13 @@ class MapDataset(UnionBaseDataset):
if count_new_transforms + count_pyfunc == len(operations):
prev_op = None
for op in operations:
# Skip user-added DebugHook ops to avoid changing them to the Python implementation.
if self.__is_debug_hook_op(op):
if prev_op:
# manually set previous_op_name
prev_op_name = self.__parse_op_name(prev_op)
op.set_previous_op_name(prev_op_name)
continue
if op.implementation is None:
if prev_op and prev_op.implementation == Implementation.PY:
op.implementation = Implementation.PY
@ -3366,27 +3373,46 @@ class MapDataset(UnionBaseDataset):
del self.process_pool
@staticmethod
def __insert_debug_wrapper(operations):
def __parse_op_name(op):
"""
Insert DebuggerWrapper before and after each op if debug mode is on.
Utility method to get operation name.
"""
if not get_debug_mode():
return operations
inserted_func = transforms.py_transforms_util.FuncWrapper(DebugWrapper())
inserted_func.implementation = Implementation.PY
inserted_operations = [inserted_func]
for op in operations:
if isinstance(op, transforms.py_transforms_util.FuncWrapper):
try:
op_name = op.transform.__name__
except Exception:
op_name = op.transform.__class__.__name__
else:
op_name = op.__class__.__name__
inserted_func = transforms.py_transforms_util.FuncWrapper(DebugWrapper(op_name))
inserted_func.implementation = Implementation.PY
inserted_operations.extend([op, inserted_func])
return inserted_operations
op_name = ""
if isinstance(op, transforms.py_transforms_util.FuncWrapper):
try:
op_name = op.transform.__name__
except (Exception,):
op_name = op.transform.__class__.__name__
else:
op_name = op.__class__.__name__
return op_name
@staticmethod
def __construct_debug_hook(previous_op_name=None):
"""
Wrap debug hook into FuncWrapper.
"""
inserted_functions = []
debug_hook_list = _get_debug_hook_list()
if debug_hook_list:
for fn in debug_hook_list:
new_fn = copy.deepcopy(fn)
new_fn.set_previous_op_name(previous_op_name)
inserted_func = transforms.py_transforms_util.FuncWrapper(new_fn)
inserted_func.implementation = Implementation.PY
inserted_functions.append(inserted_func)
return inserted_functions
@staticmethod
def __is_debug_hook_op(op):
"""
Check if the op is a user-added DebugHook; such ops are skipped to avoid changing the transform implementations.
"""
if isinstance(op, DebugHook):
if not get_debug_mode():
raise ValueError("It is not allowed to inject DebugHook object in non-debug mode.")
return True
return False
@staticmethod
def __count_pyfuncs(operations):
@ -3487,6 +3513,19 @@ class MapDataset(UnionBaseDataset):
iter_specific_operations.append(op)
self.operations = iter_specific_operations
def __insert_debug_wrapper(self, operations):
"""
Insert debug hooks before and after each op if debug mode is on.
"""
if not get_debug_mode():
return operations
inserted_operations = self.__construct_debug_hook()
for op in operations:
inserted_operations.append(op)
op_name = self.__parse_op_name(op)
inserted_operations.extend(self.__construct_debug_hook(op_name))
return inserted_operations
def __decompose_callable_operations(self):
"""
Decompose operations and build list of old legacy ops which are callable
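Taken together, __construct_debug_hook and __insert_debug_wrapper turn a map's [op1, op2] into [hook, op1, hook, op2, hook], where each inserted hook is a deep copy carrying the name of the op before it. A simplified stand-alone sketch of that interleaving (strings stand in for real transforms and hooks):

import copy

def insert_debug_wrapper(operations, hooks):
    # Hooks before the first op carry no previous-op name.
    result = [copy.deepcopy(h) for h in hooks]
    for op in operations:
        result.append(op)
        # After each op, fresh hook copies are tagged with that op's name.
        result.extend("{}@{}".format(copy.deepcopy(h), op) for h in hooks)
    return result

print(insert_debug_wrapper(["Decode", "CenterCrop"], ["PrintMetaDataHook"]))
# ['PrintMetaDataHook', 'Decode', 'PrintMetaDataHook@Decode',
#  'CenterCrop', 'PrintMetaDataHook@CenterCrop']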

View File

@ -45,6 +45,7 @@ from mindspore.context import ParallelMode
from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
from mindspore.common.api import _pynative_executor
from mindspore.dataset.core.config import get_debug_mode
from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
from mindspore.train import amp
@ -1029,6 +1030,7 @@ class Model:
if not dataset_sink_mode and _cache_enable():
raise ValueError("Embedding cache mode should run with 'dataset_sink_mode=True'.")
self._check_sink_mode_for_ds_debug_mode(dataset_sink_mode)
Validator.check_is_int(sink_size)
Validator.check_non_negative_int(epoch)
@ -1064,6 +1066,12 @@ class Model:
if _enable_distributed_mindrt():
_reset_op_id_with_offset()
@staticmethod
def _check_sink_mode_for_ds_debug_mode(dataset_sink_mode):
if get_debug_mode() and dataset_sink_mode:
raise ValueError("Dataset sink mode is not supported when dataset pipeline debug mode is on. "
"Please manually turn off sink mode.")
@staticmethod
def _check_methods_for_custom_callbacks(callbacks, current_mode):
"""

View File

@ -1,4 +1,4 @@
# Copyright 2019-2022 Huawei Technologies Co., Ltd
# Copyright 2019-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -26,6 +26,7 @@ import mindspore.dataset.engine.iterators as it
import mindspore.dataset.transforms
import mindspore.dataset.vision as vision
import mindspore.dataset.core.config as config
import mindspore.dataset.debug as debug
from mindspore import log as logger
from util import dataset_equal
@ -538,25 +539,35 @@ def test_fast_recovery():
assert "set_fast_recovery() missing 1 required positional argument: 'fast_recovery'" in str(error_info.value)
def test_debug_mode():
def test_debug_mode_error_case():
"""
Feature: Test the debug mode setter/getter function
Description: This function only accepts a boolean as input and outputs error otherwise
Expectation: TypeError will be raised when input argument is missing or is not a boolean
Feature: Test the debug mode setter function
Description: This function only accepts a boolean as its first input and a list as its second input, and raises
an error otherwise
Expectation: TypeError will be raised when an input argument is missing or invalid
"""
# set_debug_mode will raise TypeError if input is an integer
# set_debug_mode will raise TypeError if first input is an integer
config_error_func(ds.config.set_debug_mode, 0, TypeError, "debug_mode_flag isn't of type boolean.")
# set_debug_mode will raise TypeError if input is a string
# set_debug_mode will raise TypeError if first input is a string
config_error_func(ds.config.set_debug_mode, "True", TypeError, "debug_mode_flag isn't of type boolean.")
# set_debug_mode will raise TypeError if input is a tuple
# set_debug_mode will raise TypeError if first input is a tuple
config_error_func(ds.config.set_debug_mode, (True,), TypeError, "debug_mode_flag isn't of type boolean.")
# set_debug_mode will raise TypeError if input is None
# set_debug_mode will raise TypeError if first input is None
config_error_func(ds.config.set_debug_mode, None, TypeError, "debug_mode_flag isn't of type boolean.")
# set_debug_mode will raise TypeError if no input is provided
with pytest.raises(TypeError) as error_info:
ds.config.set_debug_mode()
assert "set_debug_mode() missing 1 required positional argument: 'debug_mode_flag'" in str(error_info.value)
# set_debug_mode will raise TypeError if second input is not valid
with pytest.raises(TypeError) as error_info:
ds.config.set_debug_mode(True, debug.PrintDataHook())
assert "debug_hook_list is not a list" in str(error_info.value)
def func():
pass
with pytest.raises(TypeError) as error_info:
ds.config.set_debug_mode(True, [func])
assert "All items in debug_hook_list must be of type DebugHook" in str(error_info.value)
def test_error_samples_mode():
"""
@ -608,5 +619,5 @@ if __name__ == '__main__':
test_multiprocessing_timeout_interval()
test_config_bool_type_error()
test_fast_recovery()
test_debug_mode()
test_debug_mode_error_case()
test_error_samples_mode()

View File

@ -21,8 +21,10 @@ import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms as transforms
import mindspore.dataset.vision as vision
import mindspore.nn as nn
from mindspore.dataset.vision import Inter
from mindspore import log as logger
from mindspore.train import Model
# Need to run all these tests in separate processes since
# the global configuration setting of debug_mode may impact other tests running in parallel.
@ -255,9 +257,6 @@ def test_pipeline_debug_mode_generator_pipeline():
Expectation: Output is equal to the expected output
"""
logger.info("test_pipeline_debug_mode_generator_pipeline")
# Note: set seed to make sure consistent results of Shuffle op. Even in debug mode, seed has
# been set internally in IR pre-pass, results are still random (needs further investigation).
ds.set_seed(8)
ds1 = ds.GeneratorDataset(generator_md, ["data"])
# Here ds1 should be [2, 3, 4, 5, 6, 7, 8, 9]
@ -267,6 +266,7 @@ def test_pipeline_debug_mode_generator_pipeline():
ds1 = ds1.take(7)
# do shuffle followed by batch
# Note: Since an internal seed is set in debug mode, the results of ShuffleOp are deterministic.
ds1 = ds1.shuffle(5)
ds1 = ds1.batch(3, drop_remainder=True)
@ -274,7 +274,7 @@ def test_pipeline_debug_mode_generator_pipeline():
for data in ds1.create_tuple_iterator(num_epochs=1, output_numpy=True):
buf.append(data[0])
assert len(buf) == 2
out_expect = [[[6], [3], [8]], [[4], [7], [2]]]
out_expect = [[[5], [4], [2]], [[7], [8], [3]]]
np.testing.assert_array_equal(buf, out_expect)
@ -443,15 +443,13 @@ def test_pipeline_debug_mode_shuffle():
Expectation: Successful.
"""
logger.info("test_pipeline_debug_mode_shuffle")
# Note: set seed to make sure consistent results of Shuffle op. Even in debug mode, seed has
# been set internally in IR pre-pass, results are still random (needs further investigation).
ds.set_seed(150)
buffer_size = 5
data = ds.TextFileDataset(TEXTFILE_DATA, shuffle=False)
# Note: Since an internal seed is set in debug mode, the results of ShuffleOp are deterministic.
data = data.shuffle(buffer_size=buffer_size)
out_expect = ["Good luck to everyone.", "Be happy every day.", "This is a text file.",
"Another file.", "End of file."]
out_expect = ["End of file.", "Be happy every day.", "This is a text file.", "Good luck to everyone.",
"Another file."]
num_rows = 0
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
assert item["text"] == out_expect[num_rows]
@ -782,7 +780,7 @@ def test_pipeline_debug_mode_batch_map_get_epoch_batch_num():
def test_pipeline_debug_mode_batch_map_get_epoch_num():
"""
Feature: Batch op
Feature: Dataset Debug Mode
Description: Test basic map Batch op with per_batch_map function calling get_epoch_num()
Expectation: Output is equal to the expected output
"""
@ -841,6 +839,26 @@ def test_pipeline_debug_mode_batch_map_get_epoch_num():
[[0], [-1]], [[-2], [-3]], [[0], [-1]], [[-2], [-3]]])
def test_pipeline_debug_mode_dataset_sink_not_support():
"""
Feature: Dataset Debug Mode
Description: Test dataset sink mode with debug mode enabled.
Expectation: ValueError is raised with a proper error message.
"""
dataset = ds.Cifar100Dataset("../data/dataset/testCifar100Data", num_samples=100)
def create_model():
class Net(nn.Cell):
def construct(self, x):
return x
net = Net()
return Model(net)
model = create_model()
with pytest.raises(ValueError) as error_info:
model.train(2, dataset, dataset_sink_mode=True)
assert "Dataset sink mode is not supported when dataset pipeline debug mode is on. "\
"Please manually turn off sink mode" in str(error_info.value)
if __name__ == '__main__':
setup_function()
test_pipeline_debug_mode_tuple()
@ -871,4 +889,5 @@ if __name__ == '__main__':
test_pipeline_debug_mode_cifar100_per_batch_map_mp()
test_pipeline_debug_mode_batch_map_get_epoch_batch_num()
test_pipeline_debug_mode_batch_map_get_epoch_num()
test_pipeline_debug_mode_dataset_sink_not_support()
teardown_function()

View File

@ -0,0 +1,57 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test debug hook of debug mode
"""
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.debug as dbg
# Need to run all these tests in separate processes since
# the global configuration setting of debug_mode may impact other tests running in parallel.
pytestmark = pytest.mark.forked
@pytest.mark.parametrize("debug_mode_flag, debug_hook_list",
[(True, [dbg.PrintMetaDataHook()]),
(True, [dbg.PrintDataHook()]),
(True, [])])
def test_debug_mode_hook(debug_mode_flag, debug_hook_list):
"""
Feature: Test the debug mode setter function
Description: Test valid debug hook case for debug mode
Expectation: Success
"""
# get original configs to restore after running is done.
origin_debug_mode = ds.config.get_debug_mode()
origin_seed = ds.config.get_seed()
# set debug flag and hook
ds.config.set_debug_mode(debug_mode_flag=debug_mode_flag, debug_hook_list=debug_hook_list)
dataset = ds.ImageFolderDataset("../data/dataset/testPK/data", num_samples=5)
dataset = dataset.map(operations=[vision.Decode(False), vision.CenterCrop((225, 225))])
for _ in dataset.create_dict_iterator(num_epochs=1):
pass
# restore configs
ds.config.set_debug_mode(origin_debug_mode)
ds.config.set_seed(origin_seed)
if __name__ == '__main__':
test_debug_mode_hook(True, [dbg.PrintMetaDataHook()])
test_debug_mode_hook(True, [dbg.PrintDataHook()])
test_debug_mode_hook(True, [])