forked from mindspore-Ecosystem/mindspore
!49142 [MD] Add new API to debug mode and bugfix
Merge pull request !49142 from TinaMengtingZhang/dev_md_pipeline_debug_ops
This commit is contained in:
commit
cee9da14ae
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
|
@ -75,5 +76,21 @@ Status ShuffleNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNo
|
|||
*result = std::make_shared<ShuffleNode>(ds, buffer_size, reset_every_epoch);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status ShuffleNode::Accept(IRNodePass *const p, bool *const modified) {
|
||||
RETURN_UNEXPECTED_IF_NULL(p);
|
||||
RETURN_UNEXPECTED_IF_NULL(modified);
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<ShuffleNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for IRNodePass
|
||||
Status ShuffleNode::AcceptAfter(IRNodePass *const p, bool *const modified) {
|
||||
RETURN_UNEXPECTED_IF_NULL(p);
|
||||
RETURN_UNEXPECTED_IF_NULL(modified);
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<ShuffleNode>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -57,6 +57,10 @@ class ShuffleNode : public DatasetNode {
|
|||
uint32_t ShuffleSeed() const { return shuffle_seed_; }
|
||||
bool ResetEveryEpoch() const { return reset_every_epoch_; }
|
||||
|
||||
/// \brief Setter function for shuffle_seed_
|
||||
/// \param[in] shuffle_seed The shuffle seed value to be set
|
||||
void SetShuffleSeed(uint32_t shuffle_seed) { shuffle_seed_ = shuffle_seed; }
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
|
@ -70,6 +74,18 @@ class ShuffleNode : public DatasetNode {
|
|||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
private:
|
||||
int32_t shuffle_size_;
|
||||
uint32_t shuffle_seed_;
|
||||
|
|
|
@ -19,10 +19,12 @@
|
|||
#include <string>
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
constexpr uint32_t kSeedValue = 1;
|
||||
bool DebugModePass::DebugPass::RemoveCache(std::shared_ptr<DatasetNode> node) const {
|
||||
// remove DatasetNode cache
|
||||
bool ret = false;
|
||||
|
@ -51,6 +53,19 @@ Status DebugModePass::DebugPass::Visit(std::shared_ptr<MapNode> node, bool *cons
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DebugModePass::DebugPass::Visit(std::shared_ptr<ShuffleNode> node, bool *const modified) {
|
||||
// Debug mode requires deterministic result. Replace shuffle_seed in Shuffle node with a fixed internal seed if users
|
||||
// didn't set a global config seed. (Global seed has not been configured by this time. So GlobalContext still returns
|
||||
// the user configured value.)
|
||||
uint32_t seed = GlobalContext::config_manager()->seed();
|
||||
if (seed == std::mt19937::default_seed) {
|
||||
MS_LOG(INFO) << "Replace shuffle_seed of Shuffle node with internal seed: " << std::to_string(kSeedValue) << ".";
|
||||
(void)node->SetShuffleSeed(kSeedValue);
|
||||
*modified = true;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DebugModePass::DebugPass::Visit(std::shared_ptr<DatasetNode> node, bool *const modified) {
|
||||
*modified = RemoveCache(node);
|
||||
return Status::OK();
|
||||
|
@ -67,7 +82,6 @@ Status DebugModePass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *cons
|
|||
// Debug mode requires the deterministic result. Set seed if users have not done so.
|
||||
uint32_t seed = GlobalContext::config_manager()->seed();
|
||||
if (seed == std::mt19937::default_seed) {
|
||||
int8_t kSeedValue = 1;
|
||||
MS_LOG(WARNING) << "Debug mode is enabled. Set seed to ensure deterministic results. Seed value: "
|
||||
<< std::to_string(kSeedValue) << ".";
|
||||
GlobalContext::config_manager()->set_seed(kSeedValue);
|
||||
|
|
|
@ -52,6 +52,12 @@ class DebugModePass : public IRTreePass {
|
|||
/// \return Status code
|
||||
Status Visit(std::shared_ptr<MapNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Runs a pass on ShuffleNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] *modified indicates if the node was changed at all
|
||||
/// \return Status code
|
||||
Status Visit(std::shared_ptr<ShuffleNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Runs a pass on DatasetNode
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] *modified indicates if the node was changed at all
|
||||
|
|
|
@ -31,6 +31,7 @@ import numpy
|
|||
import mindspore._c_dataengine as cde
|
||||
from mindspore import log as logger
|
||||
from mindspore.dataset.core.validator_helpers import replace_none, type_check
|
||||
from mindspore.dataset.debug import DebugHook, PrintMetaDataHook
|
||||
|
||||
__all__ = ['set_sending_batches', 'load', '_init_device_info',
|
||||
'set_seed', 'get_seed',
|
||||
|
@ -54,6 +55,7 @@ INT32_MAX = 2147483647
|
|||
UINT32_MAX = 4294967295
|
||||
|
||||
_config = cde.GlobalContext.config_manager()
|
||||
_debug_context = {}
|
||||
|
||||
|
||||
def _init_device_info():
|
||||
|
@ -840,7 +842,7 @@ def get_fast_recovery():
|
|||
return _config.get_fast_recovery()
|
||||
|
||||
|
||||
def set_debug_mode(debug_mode_flag):
|
||||
def set_debug_mode(debug_mode_flag: bool, debug_hook_list: list = None):
|
||||
"""
|
||||
Set the debug_mode flag of the dataset pipeline. When enabled, the dataset pipeline is run synchronously and
|
||||
sequentially with a single thread.
|
||||
|
@ -866,18 +868,60 @@ def set_debug_mode(debug_mode_flag):
|
|||
Args:
|
||||
debug_mode_flag (bool): Whether dataset pipeline debug mode is enabled, which forces the pipeline
|
||||
to run synchronously and sequentially.
|
||||
debug_hook_list (list[DebugHook]): a list of debug hook objects to be inserted before and after each
|
||||
transform operation in map operation. Default: None, which means to use `[PrintMetaDataHook]`,
|
||||
which prints shape/size/type of each input/output data of each transformation.
|
||||
|
||||
Raises:
|
||||
TypeError: If `debug_mode_flag` is not a boolean data type.
|
||||
TypeError: If `debug_hook_list` is not a list type.
|
||||
TypeError: If any item in `debug_hook_list` is not DebugHook type.
|
||||
|
||||
Examples:
|
||||
1. Enable dataset pipeline debug mode and use default debug hook.
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> # Print shape and type of input/output data of each transform op in map operator.
|
||||
>>> ds.config.set_debug_mode(True)
|
||||
|
||||
2. Enable dataset pipeline debug mode and use pre-defined debug hook provided by MindData.
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> import mindspore.dataset.debug as debug
|
||||
>>> ds.config.set_debug_mode(True, debug_hook_list=[debug.PrintDataHook()])
|
||||
|
||||
3. Enable dataset pipeline debug mode and use user-defined debug hook. It must define a
|
||||
class inherited from DebugHook.
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> import mindspore.dataset.debug as debug
|
||||
>>> class CustomizedHook(debug.DebugHook):
|
||||
>>> def __init__(self):
|
||||
>>> super().__init__()
|
||||
>>> def compute(self, *args):
|
||||
>>> # Add your debugging code here.
|
||||
>>> return args
|
||||
>>> ds.config.set_debug_mode(True, debug_hook_list=[CustomizedHook()])
|
||||
|
||||
4. Enable dataset pipeline debug mode and use user-defined debug hook and insert by users manually.
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> ds.config.set_debug_mode(True)
|
||||
>>> dataset = ds.ImageFolderDataset(...)
|
||||
>>> # the debug hook is added after `Decode` operation.
|
||||
>>> dataset.map([Decode(), CustomizedHook(), CenterCrop()])
|
||||
"""
|
||||
if not isinstance(debug_mode_flag, bool):
|
||||
raise TypeError("debug_mode_flag isn't of type boolean.")
|
||||
if not debug_hook_list:
|
||||
debug_hook_list = [PrintMetaDataHook()]
|
||||
if not isinstance(debug_hook_list, list):
|
||||
raise TypeError("debug_hook_list is not a list.")
|
||||
for debug_func in debug_hook_list:
|
||||
if not isinstance(debug_func, DebugHook):
|
||||
raise TypeError("All items in debug_hook_list must be of type DebugHook.")
|
||||
if debug_mode_flag:
|
||||
logger.warning("Dataset pipeline debug mode is enabled. Performance will be impacted because the pipeline"
|
||||
" will be running in a single thread.")
|
||||
if debug_hook_list:
|
||||
_debug_context["debug_hook_list"] = debug_hook_list
|
||||
|
||||
_config.set_debug_mode(debug_mode_flag)
|
||||
|
||||
|
||||
|
@ -894,6 +938,17 @@ def get_debug_mode():
|
|||
return _config.get_debug_mode()
|
||||
|
||||
|
||||
def _get_debug_hook_list():
|
||||
"""
|
||||
INTERNAL USE ONLY!
|
||||
Get value of debug_hook_list.
|
||||
|
||||
Returns:
|
||||
list, the debug hook objects to be inserted in map operation to debug inputs/outputs of each transform.
|
||||
"""
|
||||
return _debug_context.get("debug_hook_list")
|
||||
|
||||
|
||||
class ErrorSamplesMode(IntEnum):
|
||||
"""
|
||||
An enumeration for `error_samples_mode` .
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Init file for dataset pipeline debug mode"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
from mindspore.dataset.debug.debug_hook import DebugHook
|
||||
from mindspore.dataset.debug.pre_defined_hook import PrintMetaDataHook, PrintDataHook
|
||||
|
||||
__all__ = ["DebugHook", "PrintMetaDataHook", "PrintDataHook"]
|
|
@ -14,18 +14,17 @@
|
|||
# ==============================================================================
|
||||
"""
|
||||
This module defines the class for minddata pipeline debugger.
|
||||
class DebugWrapper is not exposed to users as an external API.
|
||||
class DebugHook is not exposed to users as an external API.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from abc import ABC, abstractmethod
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
class DebugWrapper:
|
||||
class DebugHook(ABC):
|
||||
"""
|
||||
A class for Minddata Python Debugger.
|
||||
The base class for Dataset Pipeline Python Debugger hook. All user defined hook behaviors
|
||||
must inherit this base class.
|
||||
|
||||
To debug the input and output data of map operation in dataset pipeline, users can add
|
||||
breakpoint to or single stepping in this class. They can also see the type and shape of
|
||||
|
@ -37,27 +36,27 @@ class DebugWrapper:
|
|||
def __init__(self, prev_op_name=None):
|
||||
self.prev_op_name = prev_op_name
|
||||
|
||||
def __call__(self, x):
|
||||
def __call__(self, *args):
|
||||
# log op name
|
||||
if self.prev_op_name:
|
||||
log_message = "Debugging the output of the operation [{}].".format(self.prev_op_name)
|
||||
else:
|
||||
log_message = "Debugging the input of the first operation."
|
||||
|
||||
# log type
|
||||
log_message += " The type is [{}].".format(type(x))
|
||||
|
||||
# log shape/size
|
||||
if isinstance(x, np.ndarray):
|
||||
log_message += " The shape is [{}].".format(x.shape)
|
||||
elif isinstance(x, Image.Image):
|
||||
log_message += " The shape is [{}].".format(x.size)
|
||||
elif isinstance(x, collections.abc.Sized):
|
||||
log_message += " The size is [{}].".format(len(x))
|
||||
logger.info(log_message)
|
||||
|
||||
######################## NOTE ########################
|
||||
# Add a breakpoint to the following line to inspect
|
||||
# input and output of each transform.
|
||||
######################################################
|
||||
logger.info(log_message)
|
||||
return x
|
||||
self.compute(args)
|
||||
return args
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, *args):
|
||||
"""
|
||||
Defines the debug behaviour to be performed. This method must be overridden by all subclasses.
|
||||
"""
|
||||
raise RuntimeError("compute() is not overridden in subclass of class DebugHook.")
|
||||
|
||||
def set_previous_op_name(self, prev_op_name):
|
||||
self.prev_op_name = prev_op_name
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
This module defines the subclass of DebugHook for minddata pipeline debugger.
|
||||
All these class are pre-defined for users for basic debugging purpose.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from mindspore import log as logger
|
||||
from mindspore.dataset.debug.debug_hook import DebugHook
|
||||
|
||||
|
||||
class PrintMetaDataHook(DebugHook):
|
||||
"""
|
||||
Debug hook used for MindData debug mode to print type and shape of data.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def compute(self, *args):
|
||||
for col_idx, col in enumerate(*args):
|
||||
log_message = "Column {}. ".format(col_idx)
|
||||
# log type
|
||||
log_message += "The type is [{}].".format(type(col))
|
||||
|
||||
# log shape/size
|
||||
if isinstance(col, np.ndarray):
|
||||
log_message += " The shape is [{}].".format(col.shape)
|
||||
elif isinstance(col, Image.Image):
|
||||
log_message += " The shape is [{}].".format(col.size)
|
||||
elif isinstance(col, collections.abc.Sized):
|
||||
log_message += " The size is [{}].".format(len(col))
|
||||
logger.info(log_message)
|
||||
return args
|
||||
|
||||
|
||||
class PrintDataHook(DebugHook):
|
||||
"""
|
||||
Debug hook used for MindData debug mode to print data.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def compute(self, *args):
|
||||
for col_idx, col in enumerate(*args):
|
||||
log_message = "Column {}. ".format(col_idx)
|
||||
if isinstance(col, Image.Image):
|
||||
data = np.asarray(col)
|
||||
log_message += "The data is [{}].".format(data)
|
||||
else:
|
||||
log_message += "The data is [{}].".format(col)
|
||||
logger.info(log_message)
|
||||
return args
|
|
@ -59,7 +59,7 @@ import mindspore.dataset.transforms.py_transforms as py_transforms
|
|||
import mindspore.dataset.transforms as transforms
|
||||
from mindspore.dataset.text.utils import SentencePieceModel, DE_C_INTER_SENTENCEPIECE_MODE
|
||||
from mindspore.parallel._utils import _get_device_num
|
||||
from mindspore.dataset.engine.debug import DebugWrapper
|
||||
from mindspore.dataset.debug import DebugHook
|
||||
|
||||
from . import samplers
|
||||
from .iterators import DictIterator, TupleIterator, DummyIterator, check_iterator_cleanup, _set_iterator_cleanup, \
|
||||
|
@ -70,7 +70,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \
|
||||
check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_padded_batch
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
get_enable_watchdog, get_seed, set_seed, get_debug_mode, get_multiprocessing_timeout_interval
|
||||
get_enable_watchdog, get_seed, set_seed, get_debug_mode, get_multiprocessing_timeout_interval, _get_debug_hook_list
|
||||
from ..core.datatypes import mstype_to_detype
|
||||
from ..core.validator_helpers import replace_none
|
||||
from ..core.py_util_helpers import ExceptionHandler
|
||||
|
@ -3336,6 +3336,13 @@ class MapDataset(UnionBaseDataset):
|
|||
if count_new_transforms + count_pyfunc == len(operations):
|
||||
prev_op = None
|
||||
for op in operations:
|
||||
# skip user added DebugHook to avoid changing to Py-implementation.
|
||||
if self.__is_debug_hook_op(op):
|
||||
if prev_op:
|
||||
# manually set previous_op_name
|
||||
prev_op_name = self.__parse_op_name(prev_op)
|
||||
op.set_previous_op_name(prev_op_name)
|
||||
continue
|
||||
if op.implementation is None:
|
||||
if prev_op and prev_op.implementation == Implementation.PY:
|
||||
op.implementation = Implementation.PY
|
||||
|
@ -3366,27 +3373,46 @@ class MapDataset(UnionBaseDataset):
|
|||
del self.process_pool
|
||||
|
||||
@staticmethod
|
||||
def __insert_debug_wrapper(operations):
|
||||
def __parse_op_name(op):
|
||||
"""
|
||||
Insert DebuggerWrapper before and after each op if debug mode is on.
|
||||
Utility method to get operation name.
|
||||
"""
|
||||
if not get_debug_mode():
|
||||
return operations
|
||||
inserted_func = transforms.py_transforms_util.FuncWrapper(DebugWrapper())
|
||||
inserted_func.implementation = Implementation.PY
|
||||
inserted_operations = [inserted_func]
|
||||
for op in operations:
|
||||
if isinstance(op, transforms.py_transforms_util.FuncWrapper):
|
||||
try:
|
||||
op_name = op.transform.__name__
|
||||
except Exception:
|
||||
op_name = op.transform.__class__.__name__
|
||||
else:
|
||||
op_name = op.__class__.__name__
|
||||
inserted_func = transforms.py_transforms_util.FuncWrapper(DebugWrapper(op_name))
|
||||
inserted_func.implementation = Implementation.PY
|
||||
inserted_operations.extend([op, inserted_func])
|
||||
return inserted_operations
|
||||
op_name = ""
|
||||
if isinstance(op, transforms.py_transforms_util.FuncWrapper):
|
||||
try:
|
||||
op_name = op.transform.__name__
|
||||
except (Exception,):
|
||||
op_name = op.transform.__class__.__name__
|
||||
else:
|
||||
op_name = op.__class__.__name__
|
||||
return op_name
|
||||
|
||||
@staticmethod
|
||||
def __construct_debug_hook(previous_op_name=None):
|
||||
"""
|
||||
Wrap debug hook into FuncWrapper.
|
||||
"""
|
||||
inserted_functions = []
|
||||
debug_hook_list = _get_debug_hook_list()
|
||||
if debug_hook_list:
|
||||
for fn in debug_hook_list:
|
||||
new_fn = copy.deepcopy(fn)
|
||||
new_fn.set_previous_op_name(previous_op_name)
|
||||
inserted_func = transforms.py_transforms_util.FuncWrapper(new_fn)
|
||||
inserted_func.implementation = Implementation.PY
|
||||
inserted_functions.append(inserted_func)
|
||||
return inserted_functions
|
||||
|
||||
@staticmethod
|
||||
def __is_debug_hook_op(op):
|
||||
"""
|
||||
Check if the op is user added DebugHook and skip it to avoid changing transforms implementation.
|
||||
"""
|
||||
if isinstance(op, DebugHook):
|
||||
if not get_debug_mode():
|
||||
raise ValueError("It is not allowed to inject DebugHook object in non-debug mode.")
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def __count_pyfuncs(operations):
|
||||
|
@ -3487,6 +3513,19 @@ class MapDataset(UnionBaseDataset):
|
|||
iter_specific_operations.append(op)
|
||||
self.operations = iter_specific_operations
|
||||
|
||||
def __insert_debug_wrapper(self, operations):
|
||||
"""
|
||||
Insert DebuggerWrapper before and after each op if debug mode is on.
|
||||
"""
|
||||
if not get_debug_mode():
|
||||
return operations
|
||||
inserted_operations = self.__construct_debug_hook()
|
||||
for op in operations:
|
||||
inserted_operations.append(op)
|
||||
op_name = self.__parse_op_name(op)
|
||||
inserted_operations.extend(self.__construct_debug_hook(op_name))
|
||||
return inserted_operations
|
||||
|
||||
def __decompose_callable_operations(self):
|
||||
"""
|
||||
Decompose operations and build list of old legacy ops which are callable
|
||||
|
|
|
@ -45,6 +45,7 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
|
||||
from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
|
||||
from mindspore.common.api import _pynative_executor
|
||||
from mindspore.dataset.core.config import get_debug_mode
|
||||
from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
|
||||
from mindspore.train import amp
|
||||
|
||||
|
@ -1029,6 +1030,7 @@ class Model:
|
|||
if not dataset_sink_mode and _cache_enable():
|
||||
raise ValueError("Embedding cache mode should run with 'dataset_sink_mode=True'.")
|
||||
|
||||
self._check_sink_mode_for_ds_debug_mode(dataset_sink_mode)
|
||||
|
||||
Validator.check_is_int(sink_size)
|
||||
Validator.check_non_negative_int(epoch)
|
||||
|
@ -1064,6 +1066,12 @@ class Model:
|
|||
if _enable_distributed_mindrt():
|
||||
_reset_op_id_with_offset()
|
||||
|
||||
@staticmethod
|
||||
def _check_sink_mode_for_ds_debug_mode(dataset_sink_mode):
|
||||
if get_debug_mode() and dataset_sink_mode:
|
||||
raise ValueError("Dataset sink mode is not supported when dataset pipeline debug mode is on. "
|
||||
"Please manually turn off sink mode.")
|
||||
|
||||
@staticmethod
|
||||
def _check_methods_for_custom_callbacks(callbacks, current_mode):
|
||||
"""
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
# Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -26,6 +26,7 @@ import mindspore.dataset.engine.iterators as it
|
|||
import mindspore.dataset.transforms
|
||||
import mindspore.dataset.vision as vision
|
||||
import mindspore.dataset.core.config as config
|
||||
import mindspore.dataset.debug as debug
|
||||
from mindspore import log as logger
|
||||
from util import dataset_equal
|
||||
|
||||
|
@ -538,25 +539,35 @@ def test_fast_recovery():
|
|||
assert "set_fast_recovery() missing 1 required positional argument: 'fast_recovery'" in str(error_info.value)
|
||||
|
||||
|
||||
def test_debug_mode():
|
||||
def test_debug_mode_error_case():
|
||||
"""
|
||||
Feature: Test the debug mode setter/getter function
|
||||
Description: This function only accepts a boolean as input and outputs error otherwise
|
||||
Expectation: TypeError will be raised when input argument is missing or is not a boolean
|
||||
Feature: Test the debug mode setter function
|
||||
Description: This function only accepts a boolean as first input and list as second input, outputs error otherwise
|
||||
Expectation: TypeError will be raised when input argument is missing or is invalid
|
||||
"""
|
||||
# set_debug_mode will raise TypeError if input is an integer
|
||||
# set_debug_mode will raise TypeError if first input is an integer
|
||||
config_error_func(ds.config.set_debug_mode, 0, TypeError, "debug_mode_flag isn't of type boolean.")
|
||||
# set_debug_mode will raise TypeError if input is a string
|
||||
# set_debug_mode will raise TypeError if first input is a string
|
||||
config_error_func(ds.config.set_debug_mode, "True", TypeError, "debug_mode_flag isn't of type boolean.")
|
||||
# set_debug_mode will raise TypeError if input is a tuple
|
||||
# set_debug_mode will raise TypeError if first input is a tuple
|
||||
config_error_func(ds.config.set_debug_mode, (True,), TypeError, "debug_mode_flag isn't of type boolean.")
|
||||
# set_debug_mode will raise TypeError if input is None
|
||||
# set_debug_mode will raise TypeError if first input is None
|
||||
config_error_func(ds.config.set_debug_mode, None, TypeError, "debug_mode_flag isn't of type boolean.")
|
||||
# set_debug_mode will raise TypeError if no input is provided
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
ds.config.set_debug_mode()
|
||||
assert "set_debug_mode() missing 1 required positional argument: 'debug_mode_flag'" in str(error_info.value)
|
||||
|
||||
# set_debug_mode will raise TypeError if second input is not valid
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
ds.config.set_debug_mode(True, debug.PrintDataHook())
|
||||
assert "debug_hook_list is not a list" in str(error_info.value)
|
||||
def func():
|
||||
pass
|
||||
with pytest.raises(TypeError) as error_info:
|
||||
ds.config.set_debug_mode(True, [func])
|
||||
assert "All items in debug_hook_list must be of type DebugHook" in str(error_info.value)
|
||||
|
||||
|
||||
def test_error_samples_mode():
|
||||
"""
|
||||
|
@ -608,5 +619,5 @@ if __name__ == '__main__':
|
|||
test_multiprocessing_timeout_interval()
|
||||
test_config_bool_type_error()
|
||||
test_fast_recovery()
|
||||
test_debug_mode()
|
||||
test_debug_mode_error_case()
|
||||
test_error_samples_mode()
|
||||
|
|
|
@ -21,8 +21,10 @@ import pytest
|
|||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms as transforms
|
||||
import mindspore.dataset.vision as vision
|
||||
import mindspore.nn as nn
|
||||
from mindspore.dataset.vision import Inter
|
||||
from mindspore import log as logger
|
||||
from mindspore.train import Model
|
||||
|
||||
# Need to run all these tests in separate processes since
|
||||
# the global configuration setting of debug_mode may impact other tests running in parallel.
|
||||
|
@ -255,9 +257,6 @@ def test_pipeline_debug_mode_generator_pipeline():
|
|||
Expectation: Output is equal to the expected output
|
||||
"""
|
||||
logger.info("test_pipeline_debug_mode_generator_pipeline")
|
||||
# Note: set seed to make sure consistent results of Shuffle op. Even in debug mode, seed has
|
||||
# been set internally in IR pre-pass, results are still random (needs further investigation).
|
||||
ds.set_seed(8)
|
||||
ds1 = ds.GeneratorDataset(generator_md, ["data"])
|
||||
|
||||
# Here ds1 should be [2, 3, 4, 5, 6, 7, 8, 9]
|
||||
|
@ -267,6 +266,7 @@ def test_pipeline_debug_mode_generator_pipeline():
|
|||
ds1 = ds1.take(7)
|
||||
|
||||
# do shuffle followed by batch
|
||||
# Note: Since the internal seed is set in debug mode, the consistency of results after ShuffleOp can be ensured.
|
||||
ds1 = ds1.shuffle(5)
|
||||
ds1 = ds1.batch(3, drop_remainder=True)
|
||||
|
||||
|
@ -274,7 +274,7 @@ def test_pipeline_debug_mode_generator_pipeline():
|
|||
for data in ds1.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
buf.append(data[0])
|
||||
assert len(buf) == 2
|
||||
out_expect = [[[6], [3], [8]], [[4], [7], [2]]]
|
||||
out_expect = [[[5], [4], [2]], [[7], [8], [3]]]
|
||||
np.testing.assert_array_equal(buf, out_expect)
|
||||
|
||||
|
||||
|
@ -443,15 +443,13 @@ def test_pipeline_debug_mode_shuffle():
|
|||
Expectation: Successful.
|
||||
"""
|
||||
logger.info("test_pipeline_debug_mode_shuffle")
|
||||
# Note: set seed to make sure consistent results of Shuffle op. Even in debug mode, seed has
|
||||
# been set internally in IR pre-pass, results are still random (needs further investigation).
|
||||
ds.set_seed(150)
|
||||
|
||||
buffer_size = 5
|
||||
data = ds.TextFileDataset(TEXTFILE_DATA, shuffle=False)
|
||||
# Note: Since the internal seed is set in debug mode, the consistency of results after ShuffleOp can be ensured.
|
||||
data = data.shuffle(buffer_size=buffer_size)
|
||||
out_expect = ["Good luck to everyone.", "Be happy every day.", "This is a text file.",
|
||||
"Another file.", "End of file."]
|
||||
out_expect = ["End of file.", "Be happy every day.", "This is a text file.", "Good luck to everyone.",
|
||||
"Another file."]
|
||||
num_rows = 0
|
||||
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
assert item["text"] == out_expect[num_rows]
|
||||
|
@ -782,7 +780,7 @@ def test_pipeline_debug_mode_batch_map_get_epoch_batch_num():
|
|||
|
||||
def test_pipeline_debug_mode_batch_map_get_epoch_num():
|
||||
"""
|
||||
Feature: Batch op
|
||||
Feature: Dataset Debug Mode
|
||||
Description: Test basic map Batch op with per_batch_map function calling get_epoch_num()
|
||||
Expectation: Output is equal to the expected output
|
||||
"""
|
||||
|
@ -841,6 +839,26 @@ def test_pipeline_debug_mode_batch_map_get_epoch_num():
|
|||
[[0], [-1]], [[-2], [-3]], [[0], [-1]], [[-2], [-3]]])
|
||||
|
||||
|
||||
def test_pipeline_debug_mode_dataset_sink_not_support():
|
||||
"""
|
||||
Feature: Dataset Debug Mode
|
||||
Description: Test dataset sink mode with debug mode enabled.
|
||||
Expectation: Raise ValueError and give proper log.
|
||||
"""
|
||||
dataset = ds.Cifar100Dataset("../data/dataset/testCifar100Data", num_samples=100)
|
||||
def create_model():
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x
|
||||
net = Net()
|
||||
return Model(net)
|
||||
model = create_model()
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
model.train(2, dataset, dataset_sink_mode=True)
|
||||
assert "Dataset sink mode is not supported when dataset pipeline debug mode is on. "\
|
||||
"Please manually turn off sink mode" in str(error_info.value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
setup_function()
|
||||
test_pipeline_debug_mode_tuple()
|
||||
|
@ -871,4 +889,5 @@ if __name__ == '__main__':
|
|||
test_pipeline_debug_mode_cifar100_per_batch_map_mp()
|
||||
test_pipeline_debug_mode_batch_map_get_epoch_batch_num()
|
||||
test_pipeline_debug_mode_batch_map_get_epoch_num()
|
||||
test_pipeline_debug_mode_dataset_sink_not_support()
|
||||
teardown_function()
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Test debug hook of debug mode
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision as vision
|
||||
import mindspore.dataset.debug as dbg
|
||||
|
||||
# Need to run all these tests in separate processes since
|
||||
# the global configuration setting of debug_mode may impact other tests running in parallel.
|
||||
pytestmark = pytest.mark.forked
|
||||
|
||||
|
||||
@pytest.mark.parametrize("debug_mode_flag, debug_hook_list",
|
||||
[(True, [dbg.PrintMetaDataHook()]),
|
||||
(True, [dbg.PrintDataHook()]),
|
||||
(True, [])])
|
||||
def test_debug_mode_hook(debug_mode_flag, debug_hook_list):
|
||||
"""
|
||||
Feature: Test the debug mode setter function
|
||||
Description: Test valid debug hook case for debug mode
|
||||
Expectation: Success
|
||||
"""
|
||||
# get original configs to restore after running is done.
|
||||
origin_debug_mode = ds.config.get_debug_mode()
|
||||
origin_seed = ds.config.get_seed()
|
||||
|
||||
# set debug flag and hook
|
||||
ds.config.set_debug_mode(debug_mode_flag=debug_mode_flag, debug_hook_list=debug_hook_list)
|
||||
dataset = ds.ImageFolderDataset("../data/dataset/testPK/data", num_samples=5)
|
||||
dataset = dataset.map(operations=[vision.Decode(False), vision.CenterCrop((225, 225))])
|
||||
for _ in dataset.create_dict_iterator(num_epochs=1):
|
||||
pass
|
||||
# restore configs
|
||||
ds.config.set_debug_mode(origin_debug_mode)
|
||||
ds.config.set_seed(origin_seed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_debug_mode_hook(True, [dbg.PrintMetaDataHook()])
|
||||
test_debug_mode_hook(True, [dbg.PrintDataHook()])
|
||||
test_debug_mode_hook(True, [])
|
Loading…
Reference in New Issue