diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/consumer/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/consumer/bindings.cc index e59a3a400ac..2d72adca9c9 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/consumer/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/consumer/bindings.cc @@ -38,6 +38,7 @@ PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) { THROW_IF_ERROR(self.GetNextAsDict(&output)); return output; }) + .def("GetOffload", [](PythonIteratorConsumer &self) { return self.GetOffload(); }) .def("GetNextAsList", [](PythonIteratorConsumer &self) { py::list output; THROW_IF_ERROR(self.GetNextAsList(&output)); @@ -123,6 +124,7 @@ PYBIND_REGISTER(ToDevice, 1, ([](const py::module *m) { .def("Send", [](ToDevice &self) { THROW_IF_ERROR(self.Send()); }) .def("ContinueSend", [](ToDevice &self) { THROW_IF_ERROR(self.Continue()); }) .def("StopSend", [](ToDevice &self) { THROW_IF_ERROR(self.Stop()); }) + .def("GetOffload", [](ToDevice &self) { return self.GetOffload(); }) .def("GetDataInfo", [](ToDevice &self) { std::vector types_c; @@ -170,6 +172,5 @@ PYBIND_REGISTER(PythonDatasetSizeGetter, 1, ([](const py::module *m) { return size; }); })); - } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/bindings.cc index cc486b10336..553ad9d7305 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/bindings.cc @@ -51,7 +51,6 @@ namespace mindspore { namespace dataset { - PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) { (void)py::class_>(*m, "Dataset") .def("set_num_workers", @@ -193,11 +192,12 @@ PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) { (void)py::class_>(*m, "MapNode", "to create a MapNode") .def(py::init([](std::shared_ptr self, py::list operations, py::list input_columns, py::list output_columns, py::list project_columns, - std::vector> py_callbacks) { + std::vector> py_callbacks, int64_t max_rowsize, + bool offload) { auto map = std::make_shared( self, std::move(toTensorOperations(operations)), toStringVector(input_columns), toStringVector(output_columns), toStringVector(project_columns), nullptr, - std::vector>(py_callbacks.begin(), py_callbacks.end())); + std::vector>(py_callbacks.begin(), py_callbacks.end()), offload); THROW_IF_ERROR(map->ValidateParams()); return map; })); @@ -297,6 +297,5 @@ PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) { return zip; })); })); - } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index 090d8a1a69f..ce84b13cbf8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -98,6 +98,8 @@ Status TreeConsumer::RegisterProfilingManager() { } #endif +std::string TreeConsumer::GetOffload() { return (tree_adapter_->GetOffloadJson()).dump(); } + // IteratorConsumer Status IteratorConsumer::Init(std::shared_ptr d) { RETURN_IF_NOT_OK(tree_adapter_->Compile(std::move(d), num_epochs_)); diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h index 36a5245c4d3..d2e70456cb3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -50,6 +50,10 @@ class TreeConsumer { /// \return Status error code virtual Status Terminate(); + /// Function for all consumers to get the offload JSON string. + /// \return Offload JSON string. + std::string GetOffload(); + #ifndef ENABLE_SECURITY virtual Status RegisterProfilingManager(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc index 401b1dff0f9..d965851c511 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc @@ -35,20 +35,21 @@ namespace dataset { MapNode::MapNode(std::shared_ptr child, std::vector> operations, std::vector input_columns, std::vector output_columns, const std::vector &project_columns, std::shared_ptr cache, - std::vector> callbacks) + std::vector> callbacks, bool offload) : operations_(operations), input_columns_(input_columns), output_columns_(output_columns), project_columns_(project_columns), DatasetNode(std::move(cache)), - callbacks_(callbacks) { + callbacks_(callbacks), + offload_(offload) { this->AddChild(child); } std::shared_ptr MapNode::Copy() { std::vector> operations = operations_; auto node = std::make_shared(nullptr, operations, input_columns_, output_columns_, project_columns_, cache_, - callbacks_); + callbacks_, offload_); return node; } @@ -151,6 +152,8 @@ void MapNode::setOperations(const std::vector> } std::vector> MapNode::operations() { return operations_; } +void MapNode::SetOffload(bool offload) { offload_ = offload; } + Status MapNode::to_json(nlohmann::json *out_json) { RETURN_UNEXPECTED_IF_NULL(out_json); nlohmann::json args; @@ -182,6 +185,7 @@ Status MapNode::to_json(nlohmann::json *out_json) { (void)std::transform(callbacks_.begin(), callbacks_.end(), std::back_inserter(cbs), [](std::shared_ptr cb) -> int32_t { return cb != nullptr ? cb->step_size() : 0; }); args["callback"] = cbs; + *out_json = args; return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h index 511f7e0e2cf..721427f5a7b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h @@ -32,7 +32,7 @@ class MapNode : public DatasetNode { MapNode(std::shared_ptr child, std::vector> operations, std::vector input_columns = {}, std::vector output_columns = {}, const std::vector &columns = {}, std::shared_ptr cache = nullptr, - std::vector> callbacks = {}); + std::vector> callbacks = {}, bool offload = false); /// \brief Destructor ~MapNode() = default; @@ -87,6 +87,10 @@ class MapNode : public DatasetNode { const std::vector &OutputColumns() const { return output_columns_; } const std::vector &ProjectColumns() const { return project_columns_; } const std::vector> &Callbacks() const { return callbacks_; } + bool GetOffload() const { return offload_; } + + /// \brief setter to set offload flag of node + void SetOffload(bool offload); /// \brief Get the arguments of node /// \param[out] out_json JSON string of all attributes @@ -118,6 +122,9 @@ class MapNode : public DatasetNode { std::vector output_columns_; std::vector project_columns_; std::vector> callbacks_; + + /// \brief Flag to indicate whether offload is set for the Map node. + bool offload_; }; } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt index 78bf0fc4a35..068429d696e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt @@ -12,6 +12,7 @@ set(DATASET_ENGINE_OPT_SRC_FILES pre/epoch_ctrl_pass.cc pre/getter_pass.cc pre/input_validation_pass.cc + pre/node_offload_pass.cc pre/node_removal_pass.cc ) diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_offload_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_offload_pass.cc new file mode 100644 index 00000000000..e1decca469e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_offload_pass.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/opt/pre/node_offload_pass.h" +#include "minddata/dataset/engine/ir/datasetops/map_node.h" +#include "minddata/dataset/engine/ir/datasetops/batch_node.h" + +namespace mindspore { +namespace dataset { +NodeOffloadPass::OffloadNodes::OffloadNodes() : prev_map_offloaded_(true) {} + +// Perform MapNode offload check. +Status NodeOffloadPass::OffloadNodes::Visit(std::shared_ptr node, bool *const modified) { + *modified = false; + // Check if this node is set to offload and add to nodes_to_offload_. + if (node->GetOffload() == true) { + MS_LOG(INFO) << "Pre pass: node offload of map class is true."; + if (prev_map_offloaded_) { + nodes_to_offload_.push_back(std::static_pointer_cast(node)); + } else { + MS_LOG(WARNING) << "Invalid use of offload in map, ignoring offload flag. Ops will be run in CPU pipeline"; + node->SetOffload(false); + *modified = true; + } + } else { + // Since map nodes are visited in reverse order, no other map ops can be offloaded after this. + prev_map_offloaded_ = false; + } + return Status::OK(); +} + +// constructor +NodeOffloadPass::NodeOffloadPass() {} + +// Walk the tree to collect the nodes to offload, fill the offload_json object, then remove the node. +Status NodeOffloadPass::RunOnTree(std::shared_ptr root_ir, bool *const modified) { + MS_LOG(INFO) << "Pre pass: node offload pass started."; + // Create the offload node pass which can identify which nodes need to be offloaded. + std::unique_ptr offload_nodes = std::make_unique(); + RETURN_IF_NOT_OK(offload_nodes->Run(root_ir, modified)); + + // Update modified flag if there were any nodes identified to be offloaded + if (offload_nodes->nodes_to_offload().empty() == false) { + *modified = true; + } + + // Then, execute the offloading of any nodes that were set up to be offloaded + for (auto node : offload_nodes->nodes_to_offload()) { + RETURN_IF_NOT_OK(node->to_json(&offload_json_)); + offload_json_["op_type"] = node->Name(); + + // Add the single offloaded node to the list of offloaded nodes and remove the node from the ir tree + offload_json_list_.push_back(offload_json_); + RETURN_IF_NOT_OK(node->Drop()); + } + MS_LOG(INFO) << "Pre pass: offload node removal pass complete."; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_offload_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_offload_pass.h new file mode 100644 index 00000000000..a70c47ac982 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_offload_pass.h @@ -0,0 +1,82 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_ + +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +class DatasetOp; + +/// \class NodeOffloadPass +/// \brief This is a tree pass that will offload nodes. It uses offload_nodes to first identify which +/// nodes should be offloaded, adds the nodes' namea to the offload list, then removes the nodes from the ir tree. +class NodeOffloadPass : public IRTreePass { + /// \class OffloadNodes + /// \brief This is a NodePass whose job is to identify which nodes should be offloaded. + class OffloadNodes : public IRNodePass { + public: + /// \brief Constructor + OffloadNodes(); + /// \brief Destructor + ~OffloadNodes() = default; + + /// \brief Perform MapNode offload check + /// \param[in] node The node being visited + /// \param[in, out] modified Indicator if the node was changed at all + /// \return Status The status code returned + Status Visit(std::shared_ptr node, bool *const modified) override; + + /// \brief Access selected offload nodes for removal. + /// \return All the nodes to be removed by offload. + std::vector> nodes_to_offload() { return nodes_to_offload_; } + + private: + std::vector> nodes_to_offload_; + bool prev_map_offloaded_; + }; + + public: + /// \brief Constructor + NodeOffloadPass(); + + /// \brief Destructor + ~NodeOffloadPass() = default; + + /// \brief Runs an offload_nodes pass first to find out which nodes to offload, then offloads them. + /// \param[in, out] root_ir The tree to operate on. + /// \param[in, out] modified Indicates if the tree was modified. + /// \return Status The status code returned + Status RunOnTree(std::shared_ptr root_ir, bool *const modified) override; + /// \brief Getter + /// \return JSON of offload + nlohmann::json GetOffloadJson() { return offload_json_list_; } + + private: + /// \brief JSON instance containing single offload op. + nlohmann::json offload_json_; + + /// \brief JSON instance containing all offload ops. + nlohmann::json offload_json_list_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc index 73369135917..4d9a402caff 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.cc @@ -61,6 +61,11 @@ Status NodeRemovalPass::RunOnTree(std::shared_ptr root_ir, bool *co std::unique_ptr removal_nodes = std::make_unique(); RETURN_IF_NOT_OK(removal_nodes->Run(root_ir, modified)); + // Update modified flag if there were any nodes identified to be removed + if (removal_nodes->nodes_to_remove().empty() == false) { + *modified = true; + } + // Then, execute the removal of any nodes that were set up for removal for (auto node : removal_nodes->nodes_to_remove()) { RETURN_IF_NOT_OK(node->Drop()); diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h index cff5a1b8fab..ab88b72bf90 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/node_removal_pass.h @@ -36,7 +36,6 @@ class NodeRemovalPass : public IRTreePass { class RemovalNodes : public IRNodePass { public: /// \brief Constructor - /// \param[in] removal_pass Raw pointer back to controlling tree pass RemovalNodes(); /// \brief Destructor diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc index 4e452f27bb6..226c6339e85 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.cc @@ -21,6 +21,7 @@ #ifndef ENABLE_ANDROID #include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" +#include "minddata/dataset/engine/opt/pre/node_offload_pass.h" #include "minddata/dataset/engine/opt/post/repeat_pass.h" #endif #include "minddata/dataset/engine/opt/pass.h" @@ -60,6 +61,14 @@ Status TreeAdapter::PrePass(std::shared_ptr ir) { if (usage_ == kDeGetter) actions.emplace_back(std::make_unique()); #ifndef ENABLE_ANDROID actions.emplace_back(std::make_unique()); + + std::unique_ptr offload = std::make_unique(); + // Checks nodes for offload removal + bool offload_mod = false; + // Checks ir_tree nodes for offload removal + offload->Run(ir, &offload_mod); + // Creates JSON object of offload nodes. + offload_json_ = offload->GetOffloadJson(); #endif // Vector of flags for each action std::vector modified(actions.size(), false); @@ -69,7 +78,8 @@ Status TreeAdapter::PrePass(std::shared_ptr ir) { RETURN_IF_NOT_OK(actions[i]->Run(ir, &m)); modified[i] = m; } - MS_LOG(INFO) << "Pre pass complete."; + + MS_LOG(INFO) << "Pre pass offload complete."; return Status::OK(); } @@ -260,5 +270,7 @@ Status TreeAdapter::Launch() { return Status::OK(); } +nlohmann::json TreeAdapter::GetOffloadJson() { return offload_json_; } + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h index 1d583d86e73..c53ff688501 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h +++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h @@ -78,6 +78,9 @@ class TreeAdapter { // Optional optimizations status bool OptimizationEnabled() const { return optimize_; } + // Return Offload Json + nlohmann::json GetOffloadJson(); + #ifndef ENABLE_SECURITY /// \brief Setter for Profiling Manager Status SetProfilingManagerPtr(std::shared_ptr profiling_manager, @@ -129,6 +132,7 @@ class TreeAdapter { kCompileStateReady // Execution tree is generated from the optimized IR }; CompileState tree_state_; + nlohmann::json offload_json_; }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py index ca1c36e56f7..4eb0e6f3d18 100644 --- a/mindspore/dataset/core/validator_helpers.py +++ b/mindspore/dataset/core/validator_helpers.py @@ -729,3 +729,12 @@ def check_c_tensor_op(param, param_name): def replace_none(value, default): """ replaces None with a default value.""" return value if value is not None else default + +def check_dataset_num_shards_shard_id(num_shards, shard_id): + if (num_shards is None) != (shard_id is None): + # These two parameters appear together. + raise ValueError("num_shards and shard_id need to be passed in together.") + if num_shards is not None: + check_pos_int32(num_shards, "num_shards") + if shard_id >= num_shards: + raise ValueError("shard_id should be less than num_shards.") diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 1d750d37048..87244486b18 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -51,6 +51,7 @@ from mindspore.common import Tensor from mindspore import log as logger from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched from mindspore.parallel._utils import _get_device_num +from mindspore.dataset.engine.offload import GetOffloadModel, op_to_model import mindspore.dataset.transforms.py_transforms as py_transforms @@ -92,6 +93,29 @@ ShuffleToShuffleMode = {Shuffle.FILES: cde.ShuffleMode.FILES, Shuffle.INFILE: cde.ShuffleMode.INFILE} +def get_offloadable_ops(operations): + """ + Check if operations are supported by offload hardware accelarator. + + Args: + operations: list of operations. + + Returns: + Dictionary with boolean key for each operation for offload support. + """ + is_offloadable = {} + if not isinstance(operations, list): + operations = [operations] + for op in operations: + name = op.__class__.__name__ + if name in op_to_model: + is_offloadable[name] = True + else: + is_offloadable[name] = False + + return is_offloadable + + def shuffle_to_shuffle_mode(shuffle): """ Shuffle Enum to Shuffle Mode @@ -650,7 +674,8 @@ class Dataset: @check_map def map(self, operations, input_columns=None, output_columns=None, column_order=None, - num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16): + num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, + max_rowsize=16, offload=False): """ Apply each operation in operations to this dataset. @@ -690,8 +715,9 @@ class Dataset: cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. (default=None, which means no cache is used). callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None). - max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy - data between processes. This is only used if python_multiprocessing is set to True (default=16). + max_rowsize (int, optional): Maximum size of row in MB that is used for shared memory allocation to copy + data between processes. This is only used if python_multiprocessing is set to True (Default=16). + offload (bool, optional): Flag to indicate whether offload is used (Default=False). Returns: @@ -785,7 +811,7 @@ class Dataset: """ return MapDataset(self, operations, input_columns, output_columns, column_order, num_parallel_workers, - python_multiprocessing, cache, callbacks, max_rowsize) + python_multiprocessing, cache, callbacks, max_rowsize, offload) @check_filter def filter(self, predicate, input_columns=None, num_parallel_workers=None): @@ -2767,13 +2793,15 @@ class MapDataset(Dataset): callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None) max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy data between processes. This is only used if python_multiprocessing is set to True (default=16). + offload (bool, optional): Flag to indicate whether offload is used (Default=False). Raises: ValueError: If len(input_columns) != len(output_columns) and column_order is not specified. """ def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, column_order=None, - num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16): + num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16, + offload=False): super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache) self.operations = to_list(operations) self.operations = py_transforms.Compose.reduce(self.operations) @@ -2799,6 +2827,20 @@ class MapDataset(Dataset): self.callbacks = to_list(callbacks) self.max_rowsize = max_rowsize + self.offload = offload + + if self.offload is True: + offloadable_ops = get_offloadable_ops(operations) + cannot_offload = False + invalid_ops = [] + for op in offloadable_ops: + if offloadable_ops[op] is not True: + cannot_offload = True + invalid_ops.append(op) + if cannot_offload is True: + logger.warning(("In map(), offload is set to True, but offload is not supported for the following " + "operation(s): {} \nSetting offload to False").format(*invalid_ops)) + self.offload = False def parse(self, children=None): operations = [] @@ -2810,7 +2852,7 @@ class MapDataset(Dataset): callbacks = [cb.create_runtime_obj() for cb in self.callbacks] return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, self.column_order, - callbacks) + callbacks, self.max_rowsize, self.offload) def __deepcopy__(self, memodict): return self.__safe_deepcopy__(memodict, exclude=("operations", "callbacks", "__transfer_dataset__")) @@ -3210,6 +3252,13 @@ class _ToDevice: def __deepcopy__(self, memodict): return self + def get_offload_model(self): + """ + Get offload model containing removed offload ops from pipeline. + """ + offload_model = GetOffloadModel(self._to_device) + return offload_model + class TransferDataset(Dataset): """ @@ -3287,6 +3336,12 @@ class TransferDataset(Dataset): return self._to_device.get_data_info() raise RuntimeError("Calling get_data_info with bad state.") + def get_offload_model(self): + if self._to_device is not None: + return self._to_device.get_offload_model() + + raise RuntimeError("get_offload_model, _to_device is None") + def release(self): """ Manually terminate Device Queue instead of relying on out of scope destruction. @@ -6478,6 +6533,7 @@ class _Flowers102Dataset: """ Mainly for loading Flowers102 Dataset, and return one row each time. """ + def __init__(self, dataset_dir, task, usage, decode): self.dataset_dir = os.path.realpath(dataset_dir) self.task = task diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 70588e70fa2..7e89f87ae44 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -20,8 +20,9 @@ import signal import weakref import numpy as np -from mindspore.common.tensor import Tensor import mindspore._c_dataengine as cde +from mindspore.common.tensor import Tensor +import mindspore.dataset.engine.offload as offload from mindspore import log as logger @@ -86,6 +87,10 @@ class Iterator: self._transform_tensor = lambda t: Tensor.from_numpy(t.as_array()) self.__index = 0 + self.offload_model = None + if offload.check_map_offload(self.__ori_dataset): + self.offload_model = offload.GetOffloadModel(consumer) + ITERATORS_LIST.append(weakref.ref(self)) _unset_iterator_cleanup() @@ -139,6 +144,10 @@ class Iterator: self.__ori_dataset.dataset_size = self.__index raise StopIteration self.__index += 1 + + if self.offload_model is not None: + data = offload.apply_offload_iterators(data, self.offload_model) + return data def __deepcopy__(self, memo): diff --git a/mindspore/dataset/engine/offload.py b/mindspore/dataset/engine/offload.py new file mode 100644 index 00000000000..725057b4fc4 --- /dev/null +++ b/mindspore/dataset/engine/offload.py @@ -0,0 +1,368 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Offload Support. +""" +import json +import numpy as np + +import mindspore.common.dtype as mstype +from mindspore.common.tensor import Tensor +import mindspore.nn as nn +import mindspore.ops.composite as C +from mindspore.ops import operations as P + + +def check_map_offload(dataset): + """ + Check if offload flag is set in data pipeline map ops. + """ + offload_ckeck = False + dataset_tmp = dataset + while dataset_tmp: + if hasattr(dataset_tmp, 'offload'): + if dataset_tmp.offload is True: + offload_ckeck = True + if dataset_tmp.children: + dataset_tmp = dataset_tmp.children[0] + continue + dataset_tmp = dataset_tmp.children + + if offload_ckeck is True: + if len(dataset.children) > 1: + raise RuntimeError("Offload currently does not support concatenated datasets.") + + return offload_ckeck + + +def apply_offload_iterators(data, offload_model): + """ + Apply offload for non sink mode pipeline. + """ + if len(data) != 2: + # A temporary solution to ensure there are two columns in dataset. + raise RuntimeError("Offload can currently only use datasets with two columns.") + if isinstance(data[0], Tensor) is True: + data[0] = offload_model(data[0]) + else: + data[0] = Tensor(data[0], dtype=mstype.float32) + data[0] = offload_model(data[0]).asnumpy() + + return data + + +class ApplyPreTransform(nn.Cell): + """ + Concatenates offload model with network. + """ + def __init__(self, transform, model): + super(ApplyPreTransform, self).__init__(auto_prefix=False, flags=model.get_flags()) + self.transform = transform + self.model = model + + def construct(self, x, label): + x = self.transform(x) + x = self.model(x, label) + return x + + +class IdentityCell(nn.Cell): + """ + Applies identity transform on given input tensors. + """ + def __init__(self): + super(IdentityCell, self).__init__() + self.identity = P.Identity() + + def construct(self, x): + return self.identity(x) + + +class RandomHorizontalFlip(nn.Cell): + """ + Applies Random Horizontal Flip transform on given input tensors. + """ + def __init__(self, prob): + super(RandomHorizontalFlip, self).__init__() + + self.prob = Tensor(prob, dtype=mstype.float32) + + self.cast = P.Cast() + self.shape = P.Shape() + self.uniformReal = P.UniformReal() + self.reshape = P.Reshape() + self.h_flip = P.ReverseV2(axis=[2]) + self.mul = P.Mul() + + def construct(self, x): + + x = self.cast(x, mstype.float32) + bs, h, w, c = self.shape(x) + + flip_rand_factor = self.uniformReal((bs, 1)) + flip_rand_factor = self.cast((self.prob > flip_rand_factor), mstype.float32) + flip_rand_factor = self.reshape(C.repeat_elements(flip_rand_factor, rep=(h*w*c)), (bs, h, w, c)) + + x_flip = self.h_flip(x) + x = self.mul(x_flip, flip_rand_factor) + self.mul((1 - flip_rand_factor), x) + + return x + + +class RandomVerticalFlip(nn.Cell): + """ + Applies Random Vertical Flip transform on given input tensors. + """ + def __init__(self, prob): + super(RandomVerticalFlip, self).__init__() + + self.prob = Tensor(prob, dtype=mstype.float32) + + self.cast = P.Cast() + self.shape = P.Shape() + self.uniformReal = P.UniformReal() + self.reshape = P.Reshape() + self.h_flip = P.ReverseV2(axis=[1]) + self.mul = P.Mul() + + def construct(self, x): + + x = self.cast(x, mstype.float32) + bs, h, w, c = self.shape(x) + + flip_rand_factor = self.uniformReal((bs, 1)) + flip_rand_factor = self.cast((self.prob > flip_rand_factor), mstype.float32) + flip_rand_factor = self.reshape(C.repeat_elements(flip_rand_factor, rep=(h*w*c)), (bs, h, w, c)) + + x_flip = self.h_flip(x) + x = self.mul(x_flip, flip_rand_factor) + self.mul((1 - flip_rand_factor), x) + + return x + + +class RandomColorAdjust(nn.Cell): + """ + Applies Random Color Adjust transform on given input tensors. + """ + def __init__(self, brightness, saturation): + super(RandomColorAdjust, self).__init__() + + if isinstance(brightness, (list, tuple)): + self.br_min = brightness[0] + self.br_max = brightness[1] + else: + self.br_min = max(0, 1 - brightness) + self.br_max = 1 + brightness + + if isinstance(saturation, (list, tuple)): + self.sa_min = saturation[0] + self.sa_max = saturation[1] + else: + self.sa_min = max(0, 1 - saturation) + self.sa_max = 1 + saturation + + self.cast = P.Cast() + self.shape = P.Shape() + self.uniformReal = P.UniformReal() + self.reshape = P.Reshape() + self.unstack = P.Unstack(axis=-1) + self.expand_dims = P.ExpandDims() + self.mul = P.Mul() + + def construct(self, x): + + x = self.cast(x, mstype.float32) + bs, h, w, c = self.shape(x) + + br_rand_factor = self.br_min + (self.br_max - self.br_min)*self.uniformReal((bs, 1)) + br_rand_factor = self.reshape(C.repeat_elements(br_rand_factor, rep=(h*w*c)), (bs, h, w, c)) + + sa_rand_factor = self.sa_min + (self.sa_max - self.sa_min)*self.uniformReal((bs, 1)) + sa_rand_factor = self.reshape(C.repeat_elements(sa_rand_factor, rep=(h*w*c)), (bs, h, w, c)) + + r, g, b = self.unstack(x) + x_gray = C.repeat_elements(self.expand_dims((0.2989 * r + 0.587 * g + 0.114 * b), -1), rep=c, axis=-1) + + x = self.mul(x, br_rand_factor) + x = C.clip_by_value(x, 0.0, 255.0) + + x = self.mul(x, sa_rand_factor) + self.mul((1 - sa_rand_factor), x_gray) + x = C.clip_by_value(x, 0.0, 255.0) + + return x + + +class RandomSharpness(nn.Cell): + """ + Applies Random Sharpness transform on given input tensors. + """ + def __init__(self, degrees): + super(RandomSharpness, self).__init__() + + if isinstance(degrees, (list, tuple)): + self.degree_min = degrees[0] + self.degree_max = degrees[1] + else: + self.degree_min = max(0, 1 - degrees) + self.degree_max = 1 + degrees + + self.cast = P.Cast() + self.shape = P.Shape() + self.uniformReal = P.UniformReal() + self.reshape = P.Reshape() + self.expand_dims = P.ExpandDims() + self.mul = P.Mul() + self.transpose = P.Transpose() + + self.weight = np.array([[1, 1, 1], [1, 5, 1], [1, 1, 1]])/13.0 + self.weight = np.repeat(self.weight[np.newaxis, :, :], 3, axis=0) + self.weight = np.repeat(self.weight[np.newaxis, :, :], 3, axis=0) + self.weight = Tensor(self.weight, mstype.float32) + + self.filter = P.Conv2D(out_channel=3, kernel_size=(3, 3), pad_mode='same') + + def construct(self, x): + + x = self.cast(x, mstype.float32) + bs, h, w, c = self.shape(x) + + degree_rand_factor = self.degree_min + (self.degree_max - self.degree_min)*self.uniformReal((bs, 1)) + degree_rand_factor = self.reshape(C.repeat_elements(degree_rand_factor, rep=(h*w*c)), (bs, h, w, c)) + + x_sharp = self.filter(self.transpose(x, (0, 3, 1, 2)), self.weight) + x_sharp = self.transpose(x_sharp, (0, 2, 3, 1)) + + x = self.mul(x, degree_rand_factor) + self.mul((1 - degree_rand_factor), x_sharp) + x = C.clip_by_value(x, 0.0, 255.0) + + return x + + +class Rescale(nn.Cell): + """ + Applies Rescale transform on given input tensors. + """ + def __init__(self, rescale, shift): + super(Rescale, self).__init__() + + self.rescale = Tensor(rescale, dtype=mstype.float32) + self.shift = Tensor(shift, dtype=mstype.float32) + + self.cast = P.Cast() + self.mul = P.Mul() + + def construct(self, x): + + x = self.cast(x, mstype.float32) + x = x * self.rescale + self.shift + + return x + + +class HwcToChw(nn.Cell): + """ + Applies Channel Swap transform on given input tensors. + """ + def __init__(self): + super(HwcToChw, self).__init__() + self.trans = P.Transpose() + + def construct(self, x): + return self.trans(x, (0, 3, 1, 2)) + + +class Normalize(nn.Cell): + """ + Applies Normalize transform on given input tensors. + """ + def __init__(self, mean, std): + super(Normalize, self).__init__() + self.mean = Tensor(mean, mstype.float32) + self.std = Tensor(std, mstype.float32) + self.sub = P.Sub() + self.div = P.Div() + self.cast = P.Cast() + + def construct(self, x): + x = self.cast(x, mstype.float32) + x = self.sub(x, self.mean) + x = self.div(x, self.std) + return x + + +class OffloadModel(): + def __init__(self, func, args_names=None): + self.func = func + self.args_names = args_names + + +# Dictionary connecting operation name to model +op_to_model = { + "HWC2CHW": OffloadModel(HwcToChw), + "HwcToChw": OffloadModel(HwcToChw), + "Normalize": OffloadModel(Normalize, ["std", "mean"]), + "RandomColorAdjust": OffloadModel(RandomColorAdjust, ["brightness", "saturation"]), + "RandomHorizontalFlip": OffloadModel(RandomHorizontalFlip, ["prob"]), + "RandomSharpness": OffloadModel(RandomSharpness, ["degrees"]), + "RandomVerticalFlip": OffloadModel(RandomVerticalFlip, ["prob"]), + "Rescale": OffloadModel(Rescale, ["rescale", "shift"]) +} + + +class GetModelFromJson2Col(nn.Cell): + """ + Generates offload ME model from offload JSON file for a single map op. + """ + def __init__(self, json_offload): + super(GetModelFromJson2Col, self).__init__() + self.me_ops = [] + if json_offload is not None: + offload_ops = json_offload["operations"] + for op in offload_ops: + name = op["tensor_op_name"] + args = op["tensor_op_params"] + op_model = op_to_model[name] + op_model_inputs = [] + if op_model.args_names is not None: + for arg_key in op_model.args_names: + op_model_inputs.append(args[arg_key]) + + self.me_ops.append(op_model.func(*op_model_inputs)) + else: + raise RuntimeError("Offload hardware accelarator cannot be applied for this pipeline.") + + self.cell = nn.SequentialCell(self.me_ops) + + def construct(self, x): + return self.cell(x) + + +class GetOffloadModel(nn.Cell): + """ + Generates offload ME model. + """ + def __init__(self, dataset_consumer): + super(GetOffloadModel, self).__init__() + self.transform_list = [] + json_offload = json.loads(dataset_consumer.GetOffload()) + if json_offload is not None: + for node in json_offload: + if node["op_type"] == 'Map': + self.transform_list.append(GetModelFromJson2Col(node)) + self.transform_list.reverse() + + def construct(self, x): + for transform in self.transform_list: + x = transform(x) + return x diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b158b4bb578..3f226a1e1f0 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -26,7 +26,7 @@ from mindspore._c_expression import typing from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \ - check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str + check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str, check_dataset_num_shards_shard_id from . import datasets from . import samplers @@ -548,13 +548,7 @@ def check_generatordataset(method): num_shards = param_dict.get("num_shards") shard_id = param_dict.get("shard_id") - if (num_shards is None) != (shard_id is None): - # These two parameters appear together. - raise ValueError("num_shards and shard_id need to be passed in together.") - if num_shards is not None: - check_pos_int32(num_shards, "num_shards") - if shard_id >= num_shards: - raise ValueError("shard_id should be less than num_shards.") + check_dataset_num_shards_shard_id(num_shards, shard_id) sampler = param_dict.get("sampler") if sampler is not None: @@ -776,7 +770,7 @@ def check_map(method): def new_method(self, *args, **kwargs): from mindspore.dataset.callback import DSCallback [_, input_columns, output_columns, column_order, num_parallel_workers, python_multiprocessing, cache, - callbacks, max_rowsize], _ = \ + callbacks, max_rowsize, offload], _ = \ parse_user_args(method, *args, **kwargs) nreq_param_columns = ['input_columns', 'output_columns', 'column_order'] @@ -788,6 +782,7 @@ def check_map(method): type_check(python_multiprocessing, (bool,), "python_multiprocessing") check_cache_option(cache) type_check(max_rowsize, (int,), "max_rowsize") + type_check(offload, (bool,), "offload") if callbacks is not None: if isinstance(callbacks, (list, tuple)): diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index a2dfadd50d9..99a58181793 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -118,6 +118,18 @@ def _generate_network_with_dataset(network, dataset_helper, queue_name): return network +def _check_add_offload(dataset, dataset_helper, network): + from mindspore.dataset.engine import offload + if offload.check_map_offload(dataset.__transfer_dataset__): + # A temporary solution to ensure there are two columns in dataset. + dataset_types, _ = dataset_helper.types_shapes() + if len(dataset_types) != 2: + raise RuntimeError("Offload can currently only use datasets with two columns.") + offload_model = dataset.__transfer_dataset__.get_offload_model() + network = offload.ApplyPreTransform(offload_model, network) + return network + + def connect_network_with_dataset(network, dataset_helper): """ Connect the `network` with dataset in `dataset_helper`. @@ -153,7 +165,6 @@ def connect_network_with_dataset(network, dataset_helper): >>> net = Net() >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper) """ - dataset_iter = dataset_helper.iter dataset = dataset_iter.dataset @@ -191,6 +202,7 @@ def connect_network_with_dataset(network, dataset_helper): not context.get_context("enable_ge") and \ context.get_context("device_target") in ("Ascend", "GPU"): dataset.__me_inited__ = True + network = _check_add_offload(dataset, dataset_helper, network) network = _generate_network_with_dataset(network, dataset_helper, queue_name) if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter): diff --git a/tests/ut/python/dataset/test_map_offload.py b/tests/ut/python/dataset/test_map_offload.py new file mode 100644 index 00000000000..565feefcd9d --- /dev/null +++ b/tests/ut/python/dataset/test_map_offload.py @@ -0,0 +1,47 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np + +import mindspore.dataset as ds +import mindspore.dataset.vision.c_transforms as C + + +DATA_DIR = "../data/dataset/testPK/data" + +def test_offload(): + """ + Feature: test map offload flag. + Description: Input is image dataset. + Expectation: Output should be same with activated or deactivated offload. + """ + # Dataset with offload activated. + dataset_0 = ds.ImageFolderDataset(DATA_DIR) + dataset_0 = dataset_0.map(operations=[C.Decode()], input_columns="image") + dataset_0 = dataset_0.map(operations=[C.HWC2CHW()], input_columns="image", offload=True) + dataset_0 = dataset_0.batch(8, drop_remainder=True) + + # Dataset with offload not activated. + dataset_1 = ds.ImageFolderDataset(DATA_DIR) + dataset_1 = dataset_1.map(operations=[C.Decode()], input_columns="image") + dataset_1 = dataset_1.map(operations=[C.HWC2CHW()], input_columns="image") + dataset_1 = dataset_1.batch(8, drop_remainder=True) + + for (img_0, _), (img_1, _) in zip(dataset_0.create_tuple_iterator(num_epochs=1, output_numpy=True), + dataset_1.create_tuple_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(img_0, img_1) + + +if __name__ == "__main__": + test_offload()