!27745 [MD][Offload] Move offload checks and split to the C pipeline
Merge pull request !27745 from markuskunej/offload_column_pipeline
commit f007b8a99f
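This change moves the offload eligibility checks and the operation-list split out of Python (dataset.py) and into the C++ IR optimizer pass (NodeOffloadPass). The boolean offload flag on MapNode becomes a tri-state ManualOffloadMode (kUnspecified / kDisabled / kEnabled), so the global auto_offload config is resolved once at pass time instead of inside map(). The pass also gains a per-column end-of-pipeline check and can split a map whose operation list is only partially offloadable.

From the user's side, map() is unchanged apart from offload now being tri-state. A minimal usage sketch (the dataset path is a placeholder):

    import mindspore.dataset as ds
    import mindspore.dataset.vision.c_transforms as C

    ds.config.set_auto_offload(True)                 # global default for maps that leave offload unset
    data = ds.ImageFolderDataset("/path/to/images")  # placeholder path
    data = data.map(operations=[C.Decode()], input_columns="image", offload=False)  # opt out explicitly
    data = data.map(operations=[C.HWC2CHW()], input_columns="image")                # None -> follows auto_offload
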
@@ -193,7 +193,7 @@ PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
                    .def(py::init([](std::shared_ptr<DatasetNode> self, py::list operations, py::list input_columns,
                                     py::list output_columns, py::list project_columns,
                                     std::vector<std::shared_ptr<PyDSCallback>> py_callbacks, int64_t max_rowsize,
-                                    bool offload) {
+                                    ManualOffloadMode offload) {
                      auto map = std::make_shared<MapNode>(
                        self, std::move(toTensorOperations(operations)), toStringVector(input_columns),
                        toStringVector(output_columns), toStringVector(project_columns), nullptr,

@@ -297,5 +297,16 @@ PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) {
                    return zip;
                  }));
                }));
+
+// OTHER PYBIND
+// (alphabetical order)
+
+PYBIND_REGISTER(ManualOffloadMode, 0, ([](const py::module *m) {
+                  (void)py::enum_<ManualOffloadMode>(*m, "ManualOffloadMode", py::arithmetic())
+                    .value("UNSPECIFIED", ManualOffloadMode::kUnspecified)
+                    .value("DISABLED", ManualOffloadMode::kDisabled)
+                    .value("ENABLED", ManualOffloadMode::kEnabled)
+                    .export_values();
+                }));
 }  // namespace dataset
 }  // namespace mindspore

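The ManualOffloadMode registration above makes the enum visible to Python. A sketch of what the binding exposes, assuming the standard alias used in dataset.py (mindspore._c_dataengine imported as cde):

    import mindspore._c_dataengine as cde  # the "cde" alias used in dataset.py

    # Mirrors ManualOffloadMode::kUnspecified / kDisabled / kEnabled on the C++ side.
    print(cde.ManualOffloadMode.UNSPECIFIED)
    print(cde.ManualOffloadMode.DISABLED)
    print(cde.ManualOffloadMode.ENABLED)
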
@@ -35,7 +35,7 @@ namespace dataset {
 MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
                  std::vector<std::string> input_columns, std::vector<std::string> output_columns,
                  const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache,
-                 std::vector<std::shared_ptr<DSCallback>> callbacks, bool offload)
+                 std::vector<std::shared_ptr<DSCallback>> callbacks, ManualOffloadMode offload)
     : operations_(operations),
       input_columns_(input_columns),
       output_columns_(output_columns),

@@ -150,7 +150,7 @@ void MapNode::setOperations(const std::vector<std::shared_ptr<TensorOperation>>
 }
 std::vector<std::shared_ptr<TensorOperation>> MapNode::operations() { return operations_; }

-void MapNode::SetOffload(bool offload) { offload_ = offload; }
+void MapNode::SetOffload(ManualOffloadMode offload) { offload_ = offload; }

 Status MapNode::to_json(nlohmann::json *out_json) {
   RETURN_UNEXPECTED_IF_NULL(out_json);

@@ -32,7 +32,8 @@ class MapNode : public DatasetNode {
   MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
           std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {},
           const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr,
-          std::vector<std::shared_ptr<DSCallback>> callbacks = {}, bool offload = false);
+          std::vector<std::shared_ptr<DSCallback>> callbacks = {},
+          ManualOffloadMode offload = ManualOffloadMode::kUnspecified);

   /// \brief Destructor
   ~MapNode() = default;

@@ -87,10 +88,10 @@ class MapNode : public DatasetNode {
   const std::vector<std::string> &OutputColumns() const { return output_columns_; }
   const std::vector<std::string> &ProjectColumns() const { return project_columns_; }
   const std::vector<std::shared_ptr<DSCallback>> &Callbacks() const { return callbacks_; }
-  bool GetOffload() const { return offload_; }
+  ManualOffloadMode GetOffload() const { return offload_; }

   /// \brief setter to set offload flag of node
-  void SetOffload(bool offload);
+  void SetOffload(ManualOffloadMode offload);

   /// \brief Get the arguments of node
   /// \param[out] out_json JSON string of all attributes

@@ -123,8 +124,8 @@ class MapNode : public DatasetNode {
   std::vector<std::string> project_columns_;
   std::vector<std::shared_ptr<DSCallback>> callbacks_;

-  /// \brief Flag to indicate whether offload is set for the Map node.
-  bool offload_;
+  /// \brief ManualOffloadMode to indicate manual_offload status
+  ManualOffloadMode offload_;
 };

 }  // namespace dataset

@@ -13,20 +13,29 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
+#include <string>

 #include "minddata/dataset/engine/opt/pre/node_offload_pass.h"
 #include "minddata/dataset/engine/ir/datasetops/map_node.h"
 #include "minddata/dataset/engine/ir/datasetops/batch_node.h"
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/kernels/ir/tensor_operation.h"

 namespace mindspore {
 namespace dataset {
-NodeOffloadPass::OffloadNodes::OffloadNodes() : prev_map_offloaded_(true) {}
+NodeOffloadPass::OffloadNodes::OffloadNodes() : auto_offload_(GlobalContext::config_manager()->get_auto_offload()) {}

 // Perform MapNode offload check.
 Status NodeOffloadPass::OffloadNodes::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
   *modified = false;
-  // Check if this node is set to offload and add to nodes_to_offload_.
-  if (node->GetOffload() == true) {
+  ManualOffloadMode manual_offload = node->GetOffload();
+  bool offload_successful = false;
+  std::vector<std::string> input_columns = node->InputColumns();
+
+  // Check if the node is set to manually offload, or if auto_offload is enabled while manual offload is not False.
+  if ((manual_offload == ManualOffloadMode::kEnabled) ||
+      ((auto_offload_ == true) && (manual_offload != ManualOffloadMode::kDisabled))) {
+    bool offload_supported = true;
     if (IS_OUTPUT_ON(mindspore::INFO)) {
       std::string operations = "operations=[";
       auto op_list = node->operations();

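The OffloadNodes constructor now snapshots the auto_offload config once, and the condition above combines it with the node's manual setting. Restated as a small self-contained sketch (names are illustrative, not MindSpore API):

    # Illustrative restatement of the C++ condition; not MindSpore API.
    def should_consider_offload(manual_offload: str, auto_offload: bool) -> bool:
        """ENABLED always qualifies; UNSPECIFIED qualifies only when
        auto_offload is on; DISABLED never qualifies."""
        return manual_offload == "ENABLED" or (auto_offload and manual_offload != "DISABLED")

    assert should_consider_offload("ENABLED", False)
    assert should_consider_offload("UNSPECIFIED", True)
    assert not should_consider_offload("UNSPECIFIED", False)
    assert not should_consider_offload("DISABLED", True)
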
@@ -40,16 +49,86 @@ Status NodeOffloadPass::OffloadNodes::Visit(std::shared_ptr<MapNode> node, bool
       operations += "]";
       MS_LOG(INFO) << "The offload of map(" + operations + ") is true, and heterogeneous acceleration will be enabled.";
     }
-    if (prev_map_offloaded_) {
-      nodes_to_offload_.push_back(std::static_pointer_cast<DatasetNode>(node));
-    } else {
-      MS_LOG(WARNING) << "Invalid use of offload in map, ignoring offload flag. Ops will be run in CPU pipeline";
-      node->SetOffload(false);
-      *modified = true;
+    // Currently offload not supported for different output_columns.
+    if (input_columns != node->OutputColumns()) {
+      MS_LOG(WARNING) << "Cannot offload map operation with output_columns != input_columns. Turning offload off.";
+      offload_supported = false;
+    }
+
+    // Check if map operation is at the end of the pipeline.
+    for (std::string input_column : input_columns) {
+      if (end_of_pipeline_.find(input_column) != end_of_pipeline_.end()) {
+        // The input column has already appeared in a previous map op.
+        if (end_of_pipeline_[input_column] == false) {
+          MS_LOG(WARNING) << "Map operation is not at the end of the pipeline for the following input column: "
+                          << input_column << ". Turning offload off.";
+          offload_supported = false;
+        }
+      } else {
+        // First time seeing input column in a Map Node, add input column to map object.
+        end_of_pipeline_[input_column] = true;
+      }
+    }
+
+    if (offload_supported) {
+      std::vector<std::string> invalid_ops;
+      std::vector<std::shared_ptr<TensorOperation>> temp_operations = node->operations();
+      bool all_valid_ops = true;
+      int last_invalid_op_pos = 1;
+      int pos = 1;
+
+      // Check individual operations to see if they are supported by offload.
+      for (auto operation : temp_operations) {
+        std::string op_name = operation->Name();
+        if (supported_ops_.find(op_name) == supported_ops_.end()) {
+          last_invalid_op_pos = pos;
+          invalid_ops.push_back(op_name);
+          all_valid_ops = false;
+        }
+        pos++;
+      }
+
+      if (all_valid_ops) {
+        // All operations can be offloaded.
+        nodes_to_offload_.push_back(std::static_pointer_cast<DatasetNode>(node));
+        offload_successful = true;
+      } else {
+        // Some operation(s) cannot be offloaded.
+        MS_LOG(WARNING)
+          << "In Map Node, offload is set to True, but offload is not supported by the following operation(s): "
+          << invalid_ops;
+
+        // See if the operations can be split into two Map Nodes
+        if (last_invalid_op_pos != temp_operations.size()) {
+          MS_LOG(WARNING) << "Map operation will be split after " << invalid_ops.back()
+                          << ", with the second map operation being offloaded.";
+          std::vector<std::shared_ptr<TensorOperation>> non_offload_ops(temp_operations.begin(),
+                                                                        temp_operations.begin() + last_invalid_op_pos);
+          std::vector<std::shared_ptr<TensorOperation>> offload_ops(temp_operations.begin() + last_invalid_op_pos,
+                                                                    temp_operations.end());
+
+          // First set operations to offload_ops to prepare for copy
+          node->setOperations(offload_ops);
+          // Copy node (returns a copy of the node, but without children)
+          std::shared_ptr<DatasetNode> offload_node = node->Copy();
+          // Set the number of parallel workers of the new node to be the same as current one.
+          offload_node->SetNumWorkers(node->NumWorkers());
+          node->setOperations(non_offload_ops);
+          // Insert the split offload map node above the original map node in the ir tree.
+          node->InsertAbove(offload_node);
+          // Add the offload map node to nodes_to_offload
+          nodes_to_offload_.push_back(offload_node);
+        }
+      }
+    }
+  }
+  if (!offload_successful) {
+    // Offload of the original node without modification did not take place.
+    // Since map nodes are visited in reverse order, no other map ops for the input_column(s) can be offloaded after
+    // this.
+    for (std::string input_column : input_columns) {
+      end_of_pipeline_[input_column] = false;
     }
-  } else {
-    // Since map nodes are visited in reverse order, no other map ops can be offloaded after this.
-    prev_map_offloaded_ = false;
   }
   return Status::OK();
 }

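Two rules drive the pass above. First, map nodes are visited in reverse pipeline order, and end_of_pipeline_ records per input column whether every later map over that column was offloaded; one non-offloaded map over a column blocks offload for all earlier maps on that same column. Second, when a candidate map mixes supported and unsupported operations, the list is cut after the last unsupported op and only the tail moves to a new, offloaded map node. A minimal Python model of the split rule (illustrative only, not MindSpore API):

    # Illustrative model of the split logic; op names come from supported_ops_ above.
    SUPPORTED_OPS = {"HwcToChw", "Normalize", "RandomColorAdjust", "RandomHorizontalFlip",
                     "RandomSharpness", "RandomVerticalFlip", "Rescale"}

    def split_for_offload(op_names):
        """Return (cpu_ops, offload_ops): everything after the last unsupported op is offloadable."""
        last_invalid = 0
        for pos, name in enumerate(op_names, start=1):
            if name not in SUPPORTED_OPS:
                last_invalid = pos
        return list(op_names[:last_invalid]), list(op_names[last_invalid:])

    # Decode is not offloadable, so the map is split right after it.
    assert split_for_offload(["Decode", "Normalize", "HwcToChw"]) == (["Decode"], ["Normalize", "HwcToChw"])
    # If the last op is unsupported, the offload half is empty and no split happens.
    assert split_for_offload(["Normalize", "Decode"]) == (["Normalize", "Decode"], [])
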
@@ -17,7 +17,10 @@
 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_
 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_

+#include <map>
 #include <memory>
+#include <set>
+#include <string>
 #include <vector>
 #include "minddata/dataset/engine/opt/pass.h"

@@ -49,8 +52,16 @@ class NodeOffloadPass : public IRTreePass {
     std::vector<std::shared_ptr<DatasetNode>> nodes_to_offload() { return nodes_to_offload_; }

    private:
     /// \brief Vector of nodes to offload
     std::vector<std::shared_ptr<DatasetNode>> nodes_to_offload_;
-    bool prev_map_offloaded_;
+    /// \brief Set of supported offload operations
+    const std::set<std::string> supported_ops_{
+      "HwcToChw", "Normalize", "RandomColorAdjust", "RandomHorizontalFlip", "RandomSharpness",
+      "RandomVerticalFlip", "Rescale"};
+    /// \brief std::map indicating if the map op for the input column is at the end of the pipeline
+    std::map<std::string, bool> end_of_pipeline_;
+    /// \brief bool indicating whether the auto_offload config option is enabled
+    bool auto_offload_;
   };

  public:

@@ -84,6 +84,13 @@ enum class MS_API NormMode {
   kOrtho = 1  ///< Ortho type norm.
 };

+/// \brief The mode for manual offload.
+enum class MS_API ManualOffloadMode {
+  kUnspecified,  ///< Not set, will use auto_offload setting instead.
+  kDisabled,     ///< Do not perform offload.
+  kEnabled       ///< Attempt to offload.
+};
+
 /// \brief Target devices to perform map operation.
 enum class MS_API MapTargetDevice {
   kCpu,  ///< CPU Device.

@@ -52,7 +52,7 @@ from mindspore.common import Tensor
 from mindspore import log as logger
 from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
 from mindspore.parallel._utils import _get_device_num
-from mindspore.dataset.engine.offload import GetOffloadModel, op_to_model
+from mindspore.dataset.engine.offload import GetOffloadModel

 import mindspore.dataset.transforms.py_transforms as py_transforms

@@ -74,7 +74,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
     check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset, check_iwslt2016_dataset, \
     check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset, check_udpos_dataset
 from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
-    get_prefetch_size, get_auto_offload
+    get_prefetch_size
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
 from ..core.validator_helpers import replace_none
 from ..core.py_util_helpers import ExceptionHandler

@@ -89,6 +89,13 @@ if platform.system().lower() == "darwin" and multiprocessing.get_start_method()
     multiprocessing.set_start_method("fork", True)


+OffloadToManualOffloadMode = {
+    None: cde.ManualOffloadMode.UNSPECIFIED,
+    False: cde.ManualOffloadMode.DISABLED,
+    True: cde.ManualOffloadMode.ENABLED
+}
+
+
 class Shuffle(str, Enum):
     GLOBAL: str = "global"
     FILES: str = "files"

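This dictionary translates the user-facing None/False/True into the pybind enum just before MapDataset.parse builds the cde.MapNode (see the parse hunk further down). A quick check of the mapping, assuming the binding registered earlier in this commit:

    assert OffloadToManualOffloadMode[None] == cde.ManualOffloadMode.UNSPECIFIED
    assert OffloadToManualOffloadMode[False] == cde.ManualOffloadMode.DISABLED
    assert OffloadToManualOffloadMode[True] == cde.ManualOffloadMode.ENABLED
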
@@ -100,95 +107,6 @@ ShuffleToShuffleMode = {Shuffle.FILES: cde.ShuffleMode.FILES,
                         Shuffle.INFILE: cde.ShuffleMode.INFILE}


-def get_offloadable_ops(operations):
-    """
-    Check if operations are supported by offload hardware accelerator.
-
-    Args:
-        operations: list of operations.
-
-    Returns:
-        Dictionary with boolean key for each operation for offload support.
-    """
-    is_offloadable = {}
-    if not isinstance(operations, list):
-        operations = [operations]
-    for op in operations:
-        name = op.__class__.__name__
-        if name in op_to_model:
-            is_offloadable[name] = True
-        else:
-            is_offloadable[name] = False
-
-    return is_offloadable
-
-
-def check_offload_map(operations, output_columns):
-    """
-    Check if operations are supported by offload hardware accelerator. If not, see if list of operations can be split
-    into two: not offload supported and offload supported
-
-    Args:
-        operations: list of operations.
-        output_columns: list of names assigned to the columns outputted by the last operation.
-
-    Returns:
-        bool, indicates whether to use offload hardware accelarator.
-        bool, indicates whether list of map operations can be split.
-        list, first group of non-offload supported operations.
-        list, second group of offload supported operations.
-    """
-    offloadable_ops = get_offloadable_ops(operations)
-    offload = True
-    can_split = False
-    offload_ops = []
-    non_offload_ops = []
-    invalid_ops = []
-    for op in offloadable_ops:
-        if offloadable_ops[op] is not True:
-            offload = False
-            invalid_ops.append(op)
-    if not offload:
-        logger.warning(("In map(), offload is set to True, but offload is not supported for the following "
-                        "operation(s): {}").format(*invalid_ops))
-
-        if output_columns:
-            # Cannot split (currently), unsure which side of operations would alter the output columns
-            logger.warning("Since output_columns is specified, the list of operations cannot be split. "
-                           "Unsure which operation(s) alter the columns. Setting offload to False.")
-
-        else:
-            # See if the map operator can be split and then offloaded
-            size = len(offloadable_ops)
-            idx = size
-            split_idx = size
-            op_names = list(offloadable_ops.keys())
-            for op_name in reversed(op_names):
-                if not offloadable_ops[op_name]:
-                    # From reverse order, this op cannot be offloaded, therefore split here.
-                    split_idx = idx
-                    break
-                idx = idx - 1
-
-            if split_idx == size:
-                # The last op in the list cannot be offloaded, therefore nothing can be offloaded.
-                # Nothing to split.
-                logger.warning(("The last operation, {}, is not supported by offload, setting offload"
-                                " to False").format(op_names[split_idx - 1]))
-
-            elif split_idx != 0:
-                # There are at least 1 offloadable ops at the end of the list.
-                # Split map() after the last non-offloadable op and only offload the second list of operations.
-                can_split = True
-                non_offload_ops = operations[:split_idx]
-                offload_ops = operations[split_idx:]
-                logger.warning(("The list of operations in map() can be split into two: {}, {}\n"
-                                "The second list of operations will be run with offload=True"
-                                ).format(op_names[:split_idx], op_names[split_idx:]))
-
-    return offload, can_split, non_offload_ops, offload_ops
-
-
 def shuffle_to_shuffle_mode(shuffle):
     """
     Shuffle Enum to Shuffle Mode

@@ -887,28 +805,8 @@ class Dataset:
             ...                       output_columns=["mod2", "mod3", "mod5", "mod7"],
             ...                       column_order=["mod7", "mod3", "col2"])
         """
-        can_split = False
-        non_offload_ops = []
-        offload_ops = []
-        if offload is not None:
-            offload_flag = offload
-        else:
-            offload_flag = get_auto_offload()
-
-        if offload_flag:
-            offload_flag, can_split, non_offload_ops, offload_ops = check_offload_map(operations, output_columns)
-
-        if can_split:
-            non_offload_map_ds = MapDataset(self, non_offload_ops, input_columns, output_columns, column_order,
-                                            num_parallel_workers, python_multiprocessing, cache, callbacks,
-                                            max_rowsize, offload=False)
-
-            return MapDataset(non_offload_map_ds, offload_ops, input_columns, output_columns, column_order,
-                              num_parallel_workers, python_multiprocessing, cache, callbacks, max_rowsize,
-                              offload=True)
-
         return MapDataset(self, operations, input_columns, output_columns, column_order, num_parallel_workers,
-                          python_multiprocessing, cache, callbacks, max_rowsize, offload_flag)
+                          python_multiprocessing, cache, callbacks, max_rowsize, offload)

     @check_filter
     def filter(self, predicate, input_columns=None, num_parallel_workers=None):

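With the checks gone from map(), a mixed operation list now reaches the C++ pass intact, and any split happens in the IR rather than in Python. A sketch of the behavior this enables, reusing the test data directory from the test file below:

    import mindspore.dataset as ds
    import mindspore.dataset.vision.c_transforms as C

    data = ds.ImageFolderDataset("../data/dataset/testPK/data")  # path taken from the tests below
    # Decode() is not offload-supported but HWC2CHW() is, so NodeOffloadPass
    # splits this map and offloads only the HWC2CHW half.
    data = data.map(operations=[C.Decode(), C.HWC2CHW()], input_columns="image", offload=True)
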
@@ -2900,7 +2798,7 @@ class MapDataset(TextBaseDataset, Dataset):
         callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None)
         max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
             data between processes. This is only used if python_multiprocessing is set to True (default=16).
-        offload (bool, optional): Flag to indicate whether offload is used (Default=False).
+        offload (bool, optional): Flag to indicate whether offload is used; None defers to the auto_offload config (Default=None).

     Raises:
         ValueError: If len(input_columns) != len(output_columns) and column_order is not specified.

@@ -2908,7 +2806,7 @@ class MapDataset(TextBaseDataset, Dataset):

     def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, column_order=None,
                  num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16,
-                 offload=False):
+                 offload=None):
         super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache)
         self.operations = to_list(operations)
         self.operations = py_transforms.Compose.reduce(self.operations)

@@ -2946,7 +2844,7 @@ class MapDataset(TextBaseDataset, Dataset):

         callbacks = [cb.create_runtime_obj() for cb in self.callbacks]
         return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, self.column_order,
-                           callbacks, self.max_rowsize, self.offload)
+                           callbacks, self.max_rowsize, OffloadToManualOffloadMode[self.offload])

     def __deepcopy__(self, memodict):
         return self.__safe_deepcopy__(memodict, exclude=("operations", "callbacks", "__transfer_dataset__"))

@@ -9566,6 +9464,7 @@ class _SVHNDataset:
     """
     Mainly for loading SVHN Dataset, and return two rows each time.
     """
+
     def __init__(self, dataset_dir, usage):
         self.dataset_dir = os.path.realpath(dataset_dir)
         self.usage = usage

@@ -88,8 +88,13 @@ class Iterator:
         self.__index = 0

         self.offload_model = None
-        if offload.check_map_offload(self.__ori_dataset):
-            self.offload_model = offload.GetOffloadModel(consumer)
+        offload_model = offload.GetOffloadModel(consumer)
+
+        # See if GetOffloadModel identified any operations set to be offloaded.
+        if offload_model.transform_list != []:
+            offload.check_concat_zip_dataset(self.__ori_dataset)
+            self.offload_model = offload_model
+

         ITERATORS_LIST.append(weakref.ref(self))
         _unset_iterator_cleanup()

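The iterator now always builds the offload model and keeps it only when the pass actually offloaded something. Mirroring the test assertions further down, a consumer can check it like this (sketch):

    it = data.create_tuple_iterator(num_epochs=1, output_numpy=True)
    if it.offload_model is not None:  # set only when transform_list is non-empty
        print(len(it.offload_model.transform_list[0].me_ops), "operation(s) offloaded")
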
@@ -30,34 +30,11 @@ def check_concat_zip_dataset(dataset):
     """
     while dataset:
         if len(dataset.children) > 1:
-            return True
+            raise RuntimeError("Offload module currently does not support concatenated or zipped datasets.")
         if dataset.children:
             dataset = dataset.children[0]
             continue
         dataset = dataset.children
-    return False
-
-
-def check_map_offload(dataset):
-    """
-    Check if offload flag is set in data pipeline map ops.
-    """
-    offload_check = False
-    concat_zip_check = check_concat_zip_dataset(dataset)
-    while dataset:
-        if hasattr(dataset, 'offload'):
-            if dataset.offload is True:
-                offload_check = True
-                break
-        if dataset.children:
-            dataset = dataset.children[0]
-        else:
-            dataset = []
-
-    if offload_check and concat_zip_check:
-        raise RuntimeError("Offload module currently does not support concatenated or zipped datasets.")
-
-    return offload_check


 def apply_offload_iterators(data, offload_model):

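Note the changed contract: check_concat_zip_dataset no longer returns a bool; it raises as soon as any node with more than one child is found, and callers now invoke it only after confirming that something was offloaded. A sketch of guarding against it, assuming a pipeline handle named dataset:

    from mindspore.dataset.engine import offload

    try:
        offload.check_concat_zip_dataset(dataset)  # dataset: any pipeline handle (assumed)
    except RuntimeError as err:
        print(err)  # offload does not support concatenated or zipped datasets
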
@@ -118,13 +118,19 @@ def _generate_network_with_dataset(network, dataset_helper, queue_name):


 def _check_add_offload(dataset, dataset_helper, network):
     """Check if any map operations were removed to be offloaded and apply the transforms if so."""
     from mindspore.dataset.engine import offload
-    if offload.check_map_offload(dataset.__transfer_dataset__):
+    if hasattr(dataset, '__no_send__'):
+        # Dataset was not sent to device. Skip adding offload.
+        return network
+    offload_model = dataset.__transfer_dataset__.get_offload_model()
+    # See if the offload pass identified any operations to be offloaded
+    if offload_model.transform_list != []:
+        offload.check_concat_zip_dataset(dataset.__transfer_dataset__)
         # A temporary solution to ensure there are two columns in dataset.
         dataset_types, _ = dataset_helper.types_shapes()
         if len(dataset_types) != 2:
             raise RuntimeError("Offload can currently only use datasets with two columns.")
-        offload_model = dataset.__transfer_dataset__.get_offload_model()
         network = offload.ApplyPreTransform(offload_model, network)
     return network

@@ -16,7 +16,9 @@ import numpy as np
 import pytest

 import mindspore.dataset as ds
+import mindspore.common.dtype as mstype
 import mindspore.dataset.vision.c_transforms as C
+import mindspore.dataset.transforms.c_transforms as C2


 DATA_DIR = "../data/dataset/testPK/data"

@@ -53,13 +55,15 @@ def test_auto_offload():
     """
     trans = [C.Decode(), C.HWC2CHW()]

-    # Dataset with config.auto_offload not activated
+    # Enable automatic offload
+    ds.config.set_auto_offload(True)
+
+    # Dataset with offload deactivated
     dataset_auto_disabled = ds.ImageFolderDataset(DATA_DIR)
-    dataset_auto_disabled = dataset_auto_disabled.map(operations=trans, input_columns="image")
+    dataset_auto_disabled = dataset_auto_disabled.map(operations=trans, input_columns="image", offload=False)
     dataset_auto_disabled = dataset_auto_disabled.batch(8, drop_remainder=True)

     # Dataset with config.auto_offload activated
-    ds.config.set_auto_offload(True)
     dataset_auto_enabled = ds.ImageFolderDataset(DATA_DIR)
     dataset_auto_enabled = dataset_auto_enabled.map(operations=trans, input_columns="image")
     dataset_auto_enabled = dataset_auto_enabled.batch(8, drop_remainder=True)

@@ -179,6 +183,46 @@ def test_offload_rescale_op():
     np.testing.assert_almost_equal(img_0, img_1, decimal=6)


+def test_offload_different_column_end_of_pipeline():
+    """
+    Feature: Test offload end_of_pipeline check.
+    Description: Input is image dataset.
+    Expectation: The image map op gets offloaded even though it comes before the not-offloaded label map op, since
+        the end_of_pipeline check looks at columns separately.
+    """
+    image_trans = [C.Decode(), C.HWC2CHW()]
+    ds.config.set_auto_offload(True)
+
+    dataset_0 = ds.ImageFolderDataset(DATA_DIR)
+    dataset_0 = dataset_0.map(operations=image_trans, input_columns="image")
+    dataset_0 = dataset_0.map(operations=[C2.TypeCast(mstype.int32)], input_columns="label", offload=False)
+
+    data_iterator = dataset_0.create_tuple_iterator(num_epochs=1, output_numpy=True)
+    # Assert at least one operation has been offloaded
+    np.testing.assert_(len(data_iterator.offload_model.transform_list[0].me_ops) > 0)
+
+    ds.config.set_auto_offload(False)
+
+
+def test_offload_not_end_of_pipeline():
+    """
+    Feature: Test offload end_of_pipeline check.
+    Description: Input is image dataset.
+    Expectation: No operations are offloaded, since the image map op at the end of the pipeline has the
+        offload flag set to False.
+    """
+    dataset_0 = ds.ImageFolderDataset(DATA_DIR)
+    dataset_0 = dataset_0.map(operations=[C.Decode()], input_columns="image", offload=True)
+    dataset_0 = dataset_0.map(operations=[C.RandomHorizontalFlip(prob=0.5)], input_columns="image", offload=True)
+    dataset_0 = dataset_0.map(operations=[C.HWC2CHW()], input_columns="image", offload=False)
+
+    dataset_0 = dataset_0.map(operations=[C2.TypeCast(mstype.int32)], input_columns="label", offload=False)
+
+    data_iterator = dataset_0.create_tuple_iterator(num_epochs=1, output_numpy=True)
+    # Assert no operations are set to be offloaded
+    np.testing.assert_(data_iterator.offload_model is None)
+
+
 if __name__ == "__main__":
     test_offload()
     test_auto_offload()

@@ -186,3 +230,5 @@ if __name__ == "__main__":
     test_offload_concat_dataset_2()
     test_offload_normalize_op()
     test_offload_rescale_op()
+    test_offload_different_column_end_of_pipeline()
+    test_offload_not_end_of_pipeline()