Adding offload hardware accelerator.

This commit is contained in:
rescue 2021-10-22 18:42:10 +00:00 committed by Rescue
parent 14efcd5a1c
commit 22e03ab12e
20 changed files with 717 additions and 28 deletions

View File

@ -38,6 +38,7 @@ PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) {
THROW_IF_ERROR(self.GetNextAsDict(&output)); THROW_IF_ERROR(self.GetNextAsDict(&output));
return output; return output;
}) })
.def("GetOffload", [](PythonIteratorConsumer &self) { return self.GetOffload(); })
.def("GetNextAsList", [](PythonIteratorConsumer &self) { .def("GetNextAsList", [](PythonIteratorConsumer &self) {
py::list output; py::list output;
THROW_IF_ERROR(self.GetNextAsList(&output)); THROW_IF_ERROR(self.GetNextAsList(&output));
@ -123,6 +124,7 @@ PYBIND_REGISTER(ToDevice, 1, ([](const py::module *m) {
.def("Send", [](ToDevice &self) { THROW_IF_ERROR(self.Send()); }) .def("Send", [](ToDevice &self) { THROW_IF_ERROR(self.Send()); })
.def("ContinueSend", [](ToDevice &self) { THROW_IF_ERROR(self.Continue()); }) .def("ContinueSend", [](ToDevice &self) { THROW_IF_ERROR(self.Continue()); })
.def("StopSend", [](ToDevice &self) { THROW_IF_ERROR(self.Stop()); }) .def("StopSend", [](ToDevice &self) { THROW_IF_ERROR(self.Stop()); })
.def("GetOffload", [](ToDevice &self) { return self.GetOffload(); })
.def("GetDataInfo", .def("GetDataInfo",
[](ToDevice &self) { [](ToDevice &self) {
std::vector<DataType> types_c; std::vector<DataType> types_c;
@ -170,6 +172,5 @@ PYBIND_REGISTER(PythonDatasetSizeGetter, 1, ([](const py::module *m) {
return size; return size;
}); });
})); }));
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -51,7 +51,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) { PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) {
(void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset") (void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset")
.def("set_num_workers", .def("set_num_workers",
@ -193,11 +192,12 @@ PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
(void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode") (void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode")
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list operations, py::list input_columns, .def(py::init([](std::shared_ptr<DatasetNode> self, py::list operations, py::list input_columns,
py::list output_columns, py::list project_columns, py::list output_columns, py::list project_columns,
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) { std::vector<std::shared_ptr<PyDSCallback>> py_callbacks, int64_t max_rowsize,
bool offload) {
auto map = std::make_shared<MapNode>( auto map = std::make_shared<MapNode>(
self, std::move(toTensorOperations(operations)), toStringVector(input_columns), self, std::move(toTensorOperations(operations)), toStringVector(input_columns),
toStringVector(output_columns), toStringVector(project_columns), nullptr, toStringVector(output_columns), toStringVector(project_columns), nullptr,
std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end())); std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end()), offload);
THROW_IF_ERROR(map->ValidateParams()); THROW_IF_ERROR(map->ValidateParams());
return map; return map;
})); }));
@ -297,6 +297,5 @@ PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) {
return zip; return zip;
})); }));
})); }));
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -98,6 +98,8 @@ Status TreeConsumer::RegisterProfilingManager() {
} }
#endif #endif
// Serialize the offload JSON held by the tree adapter and hand it back as a string.
std::string TreeConsumer::GetOffload() {
  const auto offload_json = tree_adapter_->GetOffloadJson();
  return offload_json.dump();
}
// IteratorConsumer // IteratorConsumer
Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) { Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) {
RETURN_IF_NOT_OK(tree_adapter_->Compile(std::move(d), num_epochs_)); RETURN_IF_NOT_OK(tree_adapter_->Compile(std::move(d), num_epochs_));

View File

@ -50,6 +50,10 @@ class TreeConsumer {
/// \return Status error code /// \return Status error code
virtual Status Terminate(); virtual Status Terminate();
/// Function for all consumers to get the offload JSON string.
/// \return Offload JSON string.
std::string GetOffload();
#ifndef ENABLE_SECURITY #ifndef ENABLE_SECURITY
virtual Status RegisterProfilingManager(); virtual Status RegisterProfilingManager();

View File

@ -35,20 +35,21 @@ namespace dataset {
MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns, std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache, const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache,
std::vector<std::shared_ptr<DSCallback>> callbacks) std::vector<std::shared_ptr<DSCallback>> callbacks, bool offload)
: operations_(operations), : operations_(operations),
input_columns_(input_columns), input_columns_(input_columns),
output_columns_(output_columns), output_columns_(output_columns),
project_columns_(project_columns), project_columns_(project_columns),
DatasetNode(std::move(cache)), DatasetNode(std::move(cache)),
callbacks_(callbacks) { callbacks_(callbacks),
offload_(offload) {
this->AddChild(child); this->AddChild(child);
} }
std::shared_ptr<DatasetNode> MapNode::Copy() { std::shared_ptr<DatasetNode> MapNode::Copy() {
std::vector<std::shared_ptr<TensorOperation>> operations = operations_; std::vector<std::shared_ptr<TensorOperation>> operations = operations_;
auto node = std::make_shared<MapNode>(nullptr, operations, input_columns_, output_columns_, project_columns_, cache_, auto node = std::make_shared<MapNode>(nullptr, operations, input_columns_, output_columns_, project_columns_, cache_,
callbacks_); callbacks_, offload_);
return node; return node;
} }
@ -151,6 +152,8 @@ void MapNode::setOperations(const std::vector<std::shared_ptr<TensorOperation>>
} }
std::vector<std::shared_ptr<TensorOperation>> MapNode::operations() { return operations_; } std::vector<std::shared_ptr<TensorOperation>> MapNode::operations() { return operations_; }
void MapNode::SetOffload(bool offload) { offload_ = offload; }
Status MapNode::to_json(nlohmann::json *out_json) { Status MapNode::to_json(nlohmann::json *out_json) {
RETURN_UNEXPECTED_IF_NULL(out_json); RETURN_UNEXPECTED_IF_NULL(out_json);
nlohmann::json args; nlohmann::json args;
@ -182,6 +185,7 @@ Status MapNode::to_json(nlohmann::json *out_json) {
(void)std::transform(callbacks_.begin(), callbacks_.end(), std::back_inserter(cbs), (void)std::transform(callbacks_.begin(), callbacks_.end(), std::back_inserter(cbs),
[](std::shared_ptr<DSCallback> cb) -> int32_t { return cb != nullptr ? cb->step_size() : 0; }); [](std::shared_ptr<DSCallback> cb) -> int32_t { return cb != nullptr ? cb->step_size() : 0; });
args["callback"] = cbs; args["callback"] = cbs;
*out_json = args; *out_json = args;
return Status::OK(); return Status::OK();
} }

View File

@ -32,7 +32,7 @@ class MapNode : public DatasetNode {
MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {}, std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {},
const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr, const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr,
std::vector<std::shared_ptr<DSCallback>> callbacks = {}); std::vector<std::shared_ptr<DSCallback>> callbacks = {}, bool offload = false);
/// \brief Destructor /// \brief Destructor
~MapNode() = default; ~MapNode() = default;
@ -87,6 +87,10 @@ class MapNode : public DatasetNode {
const std::vector<std::string> &OutputColumns() const { return output_columns_; } const std::vector<std::string> &OutputColumns() const { return output_columns_; }
const std::vector<std::string> &ProjectColumns() const { return project_columns_; } const std::vector<std::string> &ProjectColumns() const { return project_columns_; }
const std::vector<std::shared_ptr<DSCallback>> &Callbacks() const { return callbacks_; } const std::vector<std::shared_ptr<DSCallback>> &Callbacks() const { return callbacks_; }
bool GetOffload() const { return offload_; }
/// \brief setter to set offload flag of node
void SetOffload(bool offload);
/// \brief Get the arguments of node /// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes /// \param[out] out_json JSON string of all attributes
@ -118,6 +122,9 @@ class MapNode : public DatasetNode {
std::vector<std::string> output_columns_; std::vector<std::string> output_columns_;
std::vector<std::string> project_columns_; std::vector<std::string> project_columns_;
std::vector<std::shared_ptr<DSCallback>> callbacks_; std::vector<std::shared_ptr<DSCallback>> callbacks_;
/// \brief Flag to indicate whether offload is set for the Map node.
bool offload_;
}; };
} // namespace dataset } // namespace dataset

View File

@ -12,6 +12,7 @@ set(DATASET_ENGINE_OPT_SRC_FILES
pre/epoch_ctrl_pass.cc pre/epoch_ctrl_pass.cc
pre/getter_pass.cc pre/getter_pass.cc
pre/input_validation_pass.cc pre/input_validation_pass.cc
pre/node_offload_pass.cc
pre/node_removal_pass.cc pre/node_removal_pass.cc
) )

View File

@ -0,0 +1,73 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/opt/pre/node_offload_pass.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
namespace mindspore {
namespace dataset {
NodeOffloadPass::OffloadNodes::OffloadNodes() : prev_map_offloaded_(true) {}
// Perform the offload check on a single MapNode.
// A map whose offload flag is set is collected in nodes_to_offload_ as long as every map visited
// before it was also offloaded; otherwise its flag is cleared and the node is reported as modified.
Status NodeOffloadPass::OffloadNodes::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
  *modified = false;

  if (!node->GetOffload()) {
    // Since map nodes are visited in reverse order, no other map ops can be offloaded after this.
    prev_map_offloaded_ = false;
    return Status::OK();
  }

  MS_LOG(INFO) << "Pre pass: node offload of map class is true.";

  if (!prev_map_offloaded_) {
    // A non-offloaded map was already seen, so this offload request cannot be honoured.
    MS_LOG(WARNING) << "Invalid use of offload in map, ignoring offload flag. Ops will be run in CPU pipeline";
    node->SetOffload(false);
    *modified = true;
    return Status::OK();
  }

  // Valid offload request: remember the node so the tree pass can offload and drop it.
  nodes_to_offload_.push_back(std::static_pointer_cast<DatasetNode>(node));
  return Status::OK();
}
// Constructor: nothing to set up, the JSON members start out empty.
NodeOffloadPass::NodeOffloadPass() = default;
// Walk the tree to collect the nodes to offload, record each one as JSON, then remove it from the IR tree.
Status NodeOffloadPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
  MS_LOG(INFO) << "Pre pass: node offload pass started.";

  // Node pass that identifies which nodes need to be offloaded.
  auto offload_nodes = std::make_unique<NodeOffloadPass::OffloadNodes>();
  RETURN_IF_NOT_OK(offload_nodes->Run(root_ir, modified));

  // Fetch the selected nodes once and reuse the list below.
  const auto selected_nodes = offload_nodes->nodes_to_offload();

  // Any node picked for offload means the tree is about to change.
  if (!selected_nodes.empty()) {
    *modified = true;
  }

  // Execute the offloading of every node that was set up to be offloaded.
  for (const auto &node : selected_nodes) {
    RETURN_IF_NOT_OK(node->to_json(&offload_json_));
    offload_json_["op_type"] = node->Name();
    // Append the single offloaded op to the list of offloaded ops, then remove the node from the IR tree.
    offload_json_list_.push_back(offload_json_);
    RETURN_IF_NOT_OK(node->Drop());
  }

  MS_LOG(INFO) << "Pre pass: offload node removal pass complete.";
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,82 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_
#include <memory>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class DatasetOp;
/// \class NodeOffloadPass
/// \brief This is a tree pass that will offload nodes. It uses offload_nodes to first identify which
/// nodes should be offloaded, adds the nodes' names to the offload list, then removes the nodes from the ir tree.
class NodeOffloadPass : public IRTreePass {
  /// \class OffloadNodes
  /// \brief This is a NodePass whose job is to identify which nodes should be offloaded.
  class OffloadNodes : public IRNodePass {
   public:
    /// \brief Constructor
    OffloadNodes();

    /// \brief Destructor
    ~OffloadNodes() = default;

    /// \brief Perform MapNode offload check
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<MapNode> node, bool *const modified) override;

    /// \brief Access selected offload nodes for removal.
    /// \note Returned by const reference so callers do not copy the vector on every call; the reference
    ///     stays valid for as long as this OffloadNodes object is alive.
    /// \return All the nodes to be removed by offload.
    const std::vector<std::shared_ptr<DatasetNode>> &nodes_to_offload() const { return nodes_to_offload_; }

   private:
    /// \brief Nodes selected for offload, in the order they were visited.
    std::vector<std::shared_ptr<DatasetNode>> nodes_to_offload_;
    /// \brief True until a map node without the offload flag has been visited.
    bool prev_map_offloaded_;
  };

 public:
  /// \brief Constructor
  NodeOffloadPass();

  /// \brief Destructor
  ~NodeOffloadPass() = default;

  /// \brief Runs an offload_nodes pass first to find out which nodes to offload, then offloads them.
  /// \param[in, out] root_ir The tree to operate on.
  /// \param[in, out] modified Indicates if the tree was modified.
  /// \return Status The status code returned
  Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override;

  /// \brief Getter
  /// \return JSON of offload (the list of all offloaded ops)
  nlohmann::json GetOffloadJson() const { return offload_json_list_; }

 private:
  /// \brief JSON instance containing single offload op.
  nlohmann::json offload_json_;

  /// \brief JSON instance containing all offload ops.
  nlohmann::json offload_json_list_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_

View File

@ -61,6 +61,11 @@ Status NodeRemovalPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *co
std::unique_ptr<NodeRemovalPass::RemovalNodes> removal_nodes = std::make_unique<NodeRemovalPass::RemovalNodes>(); std::unique_ptr<NodeRemovalPass::RemovalNodes> removal_nodes = std::make_unique<NodeRemovalPass::RemovalNodes>();
RETURN_IF_NOT_OK(removal_nodes->Run(root_ir, modified)); RETURN_IF_NOT_OK(removal_nodes->Run(root_ir, modified));
// Update modified flag if there were any nodes identified to be removed
if (removal_nodes->nodes_to_remove().empty() == false) {
*modified = true;
}
// Then, execute the removal of any nodes that were set up for removal // Then, execute the removal of any nodes that were set up for removal
for (auto node : removal_nodes->nodes_to_remove()) { for (auto node : removal_nodes->nodes_to_remove()) {
RETURN_IF_NOT_OK(node->Drop()); RETURN_IF_NOT_OK(node->Drop());

View File

@ -36,7 +36,6 @@ class NodeRemovalPass : public IRTreePass {
class RemovalNodes : public IRNodePass { class RemovalNodes : public IRNodePass {
public: public:
/// \brief Constructor /// \brief Constructor
/// \param[in] removal_pass Raw pointer back to controlling tree pass
RemovalNodes(); RemovalNodes();
/// \brief Destructor /// \brief Destructor

View File

@ -21,6 +21,7 @@
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" #include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/pre/node_offload_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h" #include "minddata/dataset/engine/opt/post/repeat_pass.h"
#endif #endif
#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/opt/pass.h"
@ -60,6 +61,14 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
if (usage_ == kDeGetter) actions.emplace_back(std::make_unique<GetterPass>()); if (usage_ == kDeGetter) actions.emplace_back(std::make_unique<GetterPass>());
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
actions.emplace_back(std::make_unique<CacheTransformPass>()); actions.emplace_back(std::make_unique<CacheTransformPass>());
std::unique_ptr<NodeOffloadPass> offload = std::make_unique<NodeOffloadPass>();
// Checks nodes for offload removal
bool offload_mod = false;
// Checks ir_tree nodes for offload removal
offload->Run(ir, &offload_mod);
// Creates JSON object of offload nodes.
offload_json_ = offload->GetOffloadJson();
#endif #endif
// Vector of flags for each action // Vector of flags for each action
std::vector<bool> modified(actions.size(), false); std::vector<bool> modified(actions.size(), false);
@ -69,7 +78,8 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
RETURN_IF_NOT_OK(actions[i]->Run(ir, &m)); RETURN_IF_NOT_OK(actions[i]->Run(ir, &m));
modified[i] = m; modified[i] = m;
} }
MS_LOG(INFO) << "Pre pass complete.";
MS_LOG(INFO) << "Pre pass offload complete.";
return Status::OK(); return Status::OK();
} }
@ -260,5 +270,7 @@ Status TreeAdapter::Launch() {
return Status::OK(); return Status::OK();
} }
// Return a copy of the offload JSON stored on this adapter.
nlohmann::json TreeAdapter::GetOffloadJson() {
  return this->offload_json_;
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -78,6 +78,9 @@ class TreeAdapter {
// Optional optimizations status // Optional optimizations status
bool OptimizationEnabled() const { return optimize_; } bool OptimizationEnabled() const { return optimize_; }
// Return Offload Json
nlohmann::json GetOffloadJson();
#ifndef ENABLE_SECURITY #ifndef ENABLE_SECURITY
/// \brief Setter for Profiling Manager /// \brief Setter for Profiling Manager
Status SetProfilingManagerPtr(std::shared_ptr<ProfilingManager> profiling_manager, Status SetProfilingManagerPtr(std::shared_ptr<ProfilingManager> profiling_manager,
@ -129,6 +132,7 @@ class TreeAdapter {
kCompileStateReady // Execution tree is generated from the optimized IR kCompileStateReady // Execution tree is generated from the optimized IR
}; };
CompileState tree_state_; CompileState tree_state_;
nlohmann::json offload_json_;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -729,3 +729,12 @@ def check_c_tensor_op(param, param_name):
def replace_none(value, default): def replace_none(value, default):
""" replaces None with a default value.""" """ replaces None with a default value."""
return value if value is not None else default return value if value is not None else default
def check_dataset_num_shards_shard_id(num_shards, shard_id):
    """
    Validate the num_shards / shard_id pair of a dataset.

    Args:
        num_shards: number of shards, or None.
        shard_id: id of the shard, or None.

    Raises:
        ValueError: If only one of the two arguments is given, or if shard_id is not less than num_shards.
    """
    shards_given = num_shards is not None
    shard_id_given = shard_id is not None
    if shards_given != shard_id_given:
        # These two parameters appear together.
        raise ValueError("num_shards and shard_id need to be passed in together.")
    if not shards_given:
        return
    check_pos_int32(num_shards, "num_shards")
    if shard_id >= num_shards:
        raise ValueError("shard_id should be less than num_shards.")

View File

@ -51,6 +51,7 @@ from mindspore.common import Tensor
from mindspore import log as logger from mindspore import log as logger
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from mindspore.parallel._utils import _get_device_num from mindspore.parallel._utils import _get_device_num
from mindspore.dataset.engine.offload import GetOffloadModel, op_to_model
import mindspore.dataset.transforms.py_transforms as py_transforms import mindspore.dataset.transforms.py_transforms as py_transforms
@ -92,6 +93,29 @@ ShuffleToShuffleMode = {Shuffle.FILES: cde.ShuffleMode.FILES,
Shuffle.INFILE: cde.ShuffleMode.INFILE} Shuffle.INFILE: cde.ShuffleMode.INFILE}
def get_offloadable_ops(operations):
    """
    Check if operations are supported by offload hardware accelerator.

    Args:
        operations: a single operation or a list of operations.

    Returns:
        dict, maps the class name of each operation to True if that name is found in op_to_model,
        otherwise to False.
    """
    # A lone operation is treated as a one-element list.
    op_list = operations if isinstance(operations, list) else [operations]
    op_names = (op.__class__.__name__ for op in op_list)
    return {name: name in op_to_model for name in op_names}
def shuffle_to_shuffle_mode(shuffle): def shuffle_to_shuffle_mode(shuffle):
""" """
Shuffle Enum to Shuffle Mode Shuffle Enum to Shuffle Mode
@ -650,7 +674,8 @@ class Dataset:
@check_map @check_map
def map(self, operations, input_columns=None, output_columns=None, column_order=None, def map(self, operations, input_columns=None, output_columns=None, column_order=None,
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16): num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None,
max_rowsize=16, offload=False):
""" """
Apply each operation in operations to this dataset. Apply each operation in operations to this dataset.
@ -690,8 +715,9 @@ class Dataset:
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing. cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used). (default=None, which means no cache is used).
callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None). callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None).
max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy max_rowsize (int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
data between processes. This is only used if python_multiprocessing is set to True (default=16). data between processes. This is only used if python_multiprocessing is set to True (Default=16).
offload (bool, optional): Flag to indicate whether offload is used (Default=False).
Returns: Returns:
@ -785,7 +811,7 @@ class Dataset:
""" """
return MapDataset(self, operations, input_columns, output_columns, column_order, num_parallel_workers, return MapDataset(self, operations, input_columns, output_columns, column_order, num_parallel_workers,
python_multiprocessing, cache, callbacks, max_rowsize) python_multiprocessing, cache, callbacks, max_rowsize, offload)
@check_filter @check_filter
def filter(self, predicate, input_columns=None, num_parallel_workers=None): def filter(self, predicate, input_columns=None, num_parallel_workers=None):
@ -2767,13 +2793,15 @@ class MapDataset(Dataset):
callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None) callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None)
max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
data between processes. This is only used if python_multiprocessing is set to True (default=16). data between processes. This is only used if python_multiprocessing is set to True (default=16).
offload (bool, optional): Flag to indicate whether offload is used (Default=False).
Raises: Raises:
ValueError: If len(input_columns) != len(output_columns) and column_order is not specified. ValueError: If len(input_columns) != len(output_columns) and column_order is not specified.
""" """
def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, column_order=None, def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, column_order=None,
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16): num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16,
offload=False):
super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache) super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache)
self.operations = to_list(operations) self.operations = to_list(operations)
self.operations = py_transforms.Compose.reduce(self.operations) self.operations = py_transforms.Compose.reduce(self.operations)
@ -2799,6 +2827,20 @@ class MapDataset(Dataset):
self.callbacks = to_list(callbacks) self.callbacks = to_list(callbacks)
self.max_rowsize = max_rowsize self.max_rowsize = max_rowsize
self.offload = offload
if self.offload is True:
offloadable_ops = get_offloadable_ops(operations)
cannot_offload = False
invalid_ops = []
for op in offloadable_ops:
if offloadable_ops[op] is not True:
cannot_offload = True
invalid_ops.append(op)
if cannot_offload is True:
logger.warning(("In map(), offload is set to True, but offload is not supported for the following "
"operation(s): {} \nSetting offload to False").format(*invalid_ops))
self.offload = False
def parse(self, children=None): def parse(self, children=None):
operations = [] operations = []
@ -2810,7 +2852,7 @@ class MapDataset(Dataset):
callbacks = [cb.create_runtime_obj() for cb in self.callbacks] callbacks = [cb.create_runtime_obj() for cb in self.callbacks]
return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, self.column_order, return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, self.column_order,
callbacks) callbacks, self.max_rowsize, self.offload)
def __deepcopy__(self, memodict): def __deepcopy__(self, memodict):
return self.__safe_deepcopy__(memodict, exclude=("operations", "callbacks", "__transfer_dataset__")) return self.__safe_deepcopy__(memodict, exclude=("operations", "callbacks", "__transfer_dataset__"))
@ -3210,6 +3252,13 @@ class _ToDevice:
def __deepcopy__(self, memodict): def __deepcopy__(self, memodict):
return self return self
def get_offload_model(self):
    """
    Get offload model containing removed offload ops from pipeline.
    """
    return GetOffloadModel(self._to_device)
class TransferDataset(Dataset): class TransferDataset(Dataset):
""" """
@ -3287,6 +3336,12 @@ class TransferDataset(Dataset):
return self._to_device.get_data_info() return self._to_device.get_data_info()
raise RuntimeError("Calling get_data_info with bad state.") raise RuntimeError("Calling get_data_info with bad state.")
def get_offload_model(self):
    """
    Get the offload model from the underlying device queue object.

    Raises:
        RuntimeError: If _to_device is None.
    """
    if self._to_device is None:
        raise RuntimeError("get_offload_model, _to_device is None")
    return self._to_device.get_offload_model()
def release(self): def release(self):
""" """
Manually terminate Device Queue instead of relying on out of scope destruction. Manually terminate Device Queue instead of relying on out of scope destruction.
@ -6478,6 +6533,7 @@ class _Flowers102Dataset:
""" """
Mainly for loading Flowers102 Dataset, and return one row each time. Mainly for loading Flowers102 Dataset, and return one row each time.
""" """
def __init__(self, dataset_dir, task, usage, decode): def __init__(self, dataset_dir, task, usage, decode):
self.dataset_dir = os.path.realpath(dataset_dir) self.dataset_dir = os.path.realpath(dataset_dir)
self.task = task self.task = task

View File

@ -20,8 +20,9 @@ import signal
import weakref import weakref
import numpy as np import numpy as np
from mindspore.common.tensor import Tensor
import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
from mindspore.common.tensor import Tensor
import mindspore.dataset.engine.offload as offload
from mindspore import log as logger from mindspore import log as logger
@ -86,6 +87,10 @@ class Iterator:
self._transform_tensor = lambda t: Tensor.from_numpy(t.as_array()) self._transform_tensor = lambda t: Tensor.from_numpy(t.as_array())
self.__index = 0 self.__index = 0
self.offload_model = None
if offload.check_map_offload(self.__ori_dataset):
self.offload_model = offload.GetOffloadModel(consumer)
ITERATORS_LIST.append(weakref.ref(self)) ITERATORS_LIST.append(weakref.ref(self))
_unset_iterator_cleanup() _unset_iterator_cleanup()
@ -139,6 +144,10 @@ class Iterator:
self.__ori_dataset.dataset_size = self.__index self.__ori_dataset.dataset_size = self.__index
raise StopIteration raise StopIteration
self.__index += 1 self.__index += 1
if self.offload_model is not None:
data = offload.apply_offload_iterators(data, self.offload_model)
return data return data
def __deepcopy__(self, memo): def __deepcopy__(self, memo):

View File

@ -0,0 +1,368 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Offload Support.
"""
import json
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore.ops import operations as P
def check_map_offload(dataset):
    """
    Check if offload flag is set in data pipeline map ops.

    The pipeline is walked from `dataset` downwards, following only the first child of every node.

    Args:
        dataset: the node the walk starts from.

    Returns:
        bool, True if a visited node has its `offload` attribute set to True.

    Raises:
        RuntimeError: If offload is requested and `dataset` itself has more than one child.
    """
    offload_check = False
    current = dataset
    while current:
        if getattr(current, 'offload', None) is True:
            offload_check = True
        # Step into the first child; an empty children container ends the walk.
        current = current.children[0] if current.children else current.children

    if offload_check and len(dataset.children) > 1:
        raise RuntimeError("Offload currently does not support concatenated datasets.")

    return offload_check
def apply_offload_iterators(data, offload_model):
    """
    Apply offload for non sink mode pipeline.

    Args:
        data: row with exactly two columns; the first column is replaced in place.
        offload_model: callable applied to the first column.

    Returns:
        The same `data` object, with its first column transformed by `offload_model`.

    Raises:
        RuntimeError: If `data` does not have exactly two columns.
    """
    if len(data) != 2:
        # A temporary solution to ensure there are two columns in dataset.
        raise RuntimeError("Offload can currently only use datasets with two columns.")

    # A non-Tensor first column is converted to a float32 Tensor for the model and back afterwards.
    needs_conversion = not isinstance(data[0], Tensor)
    if needs_conversion:
        data[0] = Tensor(data[0], dtype=mstype.float32)

    transformed = offload_model(data[0])
    data[0] = transformed.asnumpy() if needs_conversion else transformed
    return data
class ApplyPreTransform(nn.Cell):
    """
    Concatenates offload model with network.

    The cell is created with `auto_prefix=False` and with the flags returned by
    `model.get_flags()`.

    Args:
        transform: Callable applied to the data input before the network runs.
        model: Network called with the transformed data and the label.
    """

    def __init__(self, transform, model):
        super().__init__(auto_prefix=False, flags=model.get_flags())
        self.transform = transform
        self.model = model

    def construct(self, x, label):
        # Only `x` goes through the transform; `label` is passed straight to the network.
        transformed = self.transform(x)
        return self.model(transformed, label)
class IdentityCell(nn.Cell):
    """
    Applies identity transform on given input tensors.
    """

    def __init__(self):
        super().__init__()
        self.identity = P.Identity()

    def construct(self, x):
        # Forward the input through the Identity primitive.
        result = self.identity(x)
        return result
class RandomHorizontalFlip(nn.Cell):
    """
    Applies Random Horizontal Flip transform on given input tensors.

    Args:
        prob: Value stored as a float32 Tensor and compared against one
            UniformReal draw per sample to decide whether that sample is flipped.
    """

    def __init__(self, prob):
        super().__init__()
        self.prob = Tensor(prob, dtype=mstype.float32)
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.uniformReal = P.UniformReal()
        self.reshape = P.Reshape()
        # Reverses axis 2, the `w` axis of the (bs, h, w, c) shape unpacked in construct.
        self.h_flip = P.ReverseV2(axis=[2])
        self.mul = P.Mul()

    def construct(self, x):
        x = self.cast(x, mstype.float32)
        bs, h, w, c = self.shape(x)
        # One draw per sample, turned into a 0/1 mask: 1 where prob exceeds the draw.
        draw = self.uniformReal((bs, 1))
        mask = self.cast((self.prob > draw), mstype.float32)
        # Spread each per-sample value over all h*w*c elements of that sample.
        mask = self.reshape(C.repeat_elements(mask, rep=(h*w*c)), (bs, h, w, c))
        flipped = self.h_flip(x)
        # Blend: the flipped image where mask is 1, the original where it is 0.
        return self.mul(flipped, mask) + self.mul((1 - mask), x)
class RandomVerticalFlip(nn.Cell):
    """
    Applies Random Vertical Flip transform on given input tensors.

    Args:
        prob: Value stored as a float32 Tensor and compared against one
            UniformReal draw per sample to decide whether that sample is flipped.
    """

    def __init__(self, prob):
        super().__init__()
        self.prob = Tensor(prob, dtype=mstype.float32)
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.uniformReal = P.UniformReal()
        self.reshape = P.Reshape()
        # Despite the attribute name, this reverses axis 1, the `h` axis of the
        # (bs, h, w, c) shape unpacked in construct.
        self.h_flip = P.ReverseV2(axis=[1])
        self.mul = P.Mul()

    def construct(self, x):
        x = self.cast(x, mstype.float32)
        bs, h, w, c = self.shape(x)
        # One draw per sample, turned into a 0/1 mask: 1 where prob exceeds the draw.
        draw = self.uniformReal((bs, 1))
        mask = self.cast((self.prob > draw), mstype.float32)
        # Spread each per-sample value over all h*w*c elements of that sample.
        mask = self.reshape(C.repeat_elements(mask, rep=(h*w*c)), (bs, h, w, c))
        flipped = self.h_flip(x)
        # Blend: the flipped image where mask is 1, the original where it is 0.
        return self.mul(flipped, mask) + self.mul((1 - mask), x)
class RandomColorAdjust(nn.Cell):
    """
    Applies Random Color Adjust transform on given input tensors.

    Args:
        brightness: Either a (min, max) list/tuple, or a single number `b` that is
            turned into the range [max(0, 1 - b), 1 + b].
        saturation: Same convention as `brightness`, for the saturation factor.
    """

    def __init__(self, brightness, saturation):
        super().__init__()

        # Brightness factor range.
        if isinstance(brightness, (list, tuple)):
            self.br_min, self.br_max = brightness[0], brightness[1]
        else:
            self.br_min, self.br_max = max(0, 1 - brightness), 1 + brightness

        # Saturation factor range.
        if isinstance(saturation, (list, tuple)):
            self.sa_min, self.sa_max = saturation[0], saturation[1]
        else:
            self.sa_min, self.sa_max = max(0, 1 - saturation), 1 + saturation

        self.cast = P.Cast()
        self.shape = P.Shape()
        self.uniformReal = P.UniformReal()
        self.reshape = P.Reshape()
        self.unstack = P.Unstack(axis=-1)
        self.expand_dims = P.ExpandDims()
        self.mul = P.Mul()

    def construct(self, x):
        x = self.cast(x, mstype.float32)
        bs, h, w, c = self.shape(x)
        full_shape = (bs, h, w, c)
        per_sample = h*w*c

        # One brightness factor per sample, expanded to the full batch shape.
        br_rand_factor = self.br_min + (self.br_max - self.br_min)*self.uniformReal((bs, 1))
        br_rand_factor = self.reshape(C.repeat_elements(br_rand_factor, rep=per_sample), full_shape)

        # One saturation factor per sample, expanded the same way.
        sa_rand_factor = self.sa_min + (self.sa_max - self.sa_min)*self.uniformReal((bs, 1))
        sa_rand_factor = self.reshape(C.repeat_elements(sa_rand_factor, rep=per_sample), full_shape)

        # The grayscale reference is computed from the input before brightness is applied.
        # Unstacking the last axis into three names requires exactly three channels.
        r, g, b = self.unstack(x)
        luminance = 0.2989 * r + 0.587 * g + 0.114 * b
        x_gray = C.repeat_elements(self.expand_dims(luminance, -1), rep=c, axis=-1)

        # Brightness: scale, then clamp to [0, 255].
        x = C.clip_by_value(self.mul(x, br_rand_factor), 0.0, 255.0)

        # Saturation: blend with the grayscale reference, then clamp again.
        x = self.mul(x, sa_rand_factor) + self.mul((1 - sa_rand_factor), x_gray)
        return C.clip_by_value(x, 0.0, 255.0)
class RandomSharpness(nn.Cell):
    """
    Applies Random Sharpness transform on given input tensors.

    Args:
        degrees: Either a (min, max) list/tuple, or a single number `d` that is
            turned into the range [max(0, 1 - d), 1 + d].
    """

    def __init__(self, degrees):
        super().__init__()

        if isinstance(degrees, (list, tuple)):
            self.degree_min, self.degree_max = degrees[0], degrees[1]
        else:
            self.degree_min, self.degree_max = max(0, 1 - degrees), 1 + degrees

        self.cast = P.Cast()
        self.shape = P.Shape()
        self.uniformReal = P.UniformReal()
        self.reshape = P.Reshape()
        self.expand_dims = P.ExpandDims()
        self.mul = P.Mul()
        self.transpose = P.Transpose()

        # 3x3 kernel replicated into a (3, 3, 3, 3) array, then stored as a float32 Tensor.
        # NOTE(review): the same kernel fills every pair of leading indices, so the
        # convolution mixes all input channels into each output channel - confirm
        # this is intended.
        self.weight = np.tile(np.array([[1, 1, 1], [1, 5, 1], [1, 1, 1]])/13.0, (3, 3, 1, 1))
        self.weight = Tensor(self.weight, mstype.float32)

        self.filter = P.Conv2D(out_channel=3, kernel_size=(3, 3), pad_mode='same')

    def construct(self, x):
        x = self.cast(x, mstype.float32)
        bs, h, w, c = self.shape(x)

        # One blending factor per sample, expanded to the full batch shape.
        degree_rand_factor = self.degree_min + (self.degree_max - self.degree_min)*self.uniformReal((bs, 1))
        degree_rand_factor = self.reshape(C.repeat_elements(degree_rand_factor, rep=(h*w*c)), (bs, h, w, c))

        # The filter runs on the (0, 3, 1, 2) permutation of the input; the result is permuted back.
        channels_first = self.transpose(x, (0, 3, 1, 2))
        x_sharp = self.transpose(self.filter(channels_first, self.weight), (0, 2, 3, 1))

        # Blend the input with the filtered image and clamp to [0, 255].
        x = self.mul(x, degree_rand_factor) + self.mul((1 - degree_rand_factor), x_sharp)
        return C.clip_by_value(x, 0.0, 255.0)
class Rescale(nn.Cell):
    """
    Applies Rescale transform on given input tensors.

    Computes `x * rescale + shift` on the input cast to float32.

    Args:
        rescale: Multiplicative factor, stored as a float32 Tensor.
        shift: Additive offset, stored as a float32 Tensor.
    """

    def __init__(self, rescale, shift):
        super().__init__()
        self.rescale = Tensor(rescale, dtype=mstype.float32)
        self.shift = Tensor(shift, dtype=mstype.float32)
        self.cast = P.Cast()
        # Not used by construct; kept so the set of attributes is unchanged.
        self.mul = P.Mul()

    def construct(self, x):
        as_float = self.cast(x, mstype.float32)
        return as_float * self.rescale + self.shift
class HwcToChw(nn.Cell):
    """
    Applies Channel Swap transform on given input tensors.

    Permutes the axes of a 4-D input with the order (0, 3, 1, 2).
    """

    def __init__(self):
        super().__init__()
        self.trans = P.Transpose()

    def construct(self, x):
        # Move the last axis to position 1; the first axis stays in place.
        permutation = (0, 3, 1, 2)
        return self.trans(x, permutation)
class Normalize(nn.Cell):
    """
    Applies Normalize transform on given input tensors.

    Computes `(x - mean) / std` on the input cast to float32.

    Args:
        mean: Value subtracted from the input, stored as a float32 Tensor.
        std: Value the centred input is divided by, stored as a float32 Tensor.
    """

    def __init__(self, mean, std):
        super().__init__()
        self.mean = Tensor(mean, mstype.float32)
        self.std = Tensor(std, mstype.float32)
        self.sub = P.Sub()
        self.div = P.Div()
        self.cast = P.Cast()

    def construct(self, x):
        as_float = self.cast(x, mstype.float32)
        centred = self.sub(as_float, self.mean)
        return self.div(centred, self.std)
class OffloadModel:
    """
    Pairs a callable with the names of the arguments it is built from.

    Args:
        func: Callable stored as `func`.
        args_names: Optional list of argument names stored as `args_names`;
            None when `func` takes no arguments.
    """

    def __init__(self, func, args_names=None):
        self.func, self.args_names = func, args_names
# Dictionary connecting operation name to model.
# The names in each `args_names` list are looked up in the op's JSON parameters and the
# values are passed to the cell constructor positionally (see GetModelFromJson2Col), so
# the order of the names must match the order of the constructor parameters.
op_to_model = {
    "HWC2CHW": OffloadModel(HwcToChw),
    "HwcToChw": OffloadModel(HwcToChw),
    # Normalize.__init__ takes (mean, std); listing "std" first would swap the two values.
    "Normalize": OffloadModel(Normalize, ["mean", "std"]),
    "RandomColorAdjust": OffloadModel(RandomColorAdjust, ["brightness", "saturation"]),
    "RandomHorizontalFlip": OffloadModel(RandomHorizontalFlip, ["prob"]),
    "RandomSharpness": OffloadModel(RandomSharpness, ["degrees"]),
    "RandomVerticalFlip": OffloadModel(RandomVerticalFlip, ["prob"]),
    "Rescale": OffloadModel(Rescale, ["rescale", "shift"])
}
class GetModelFromJson2Col(nn.Cell):
    """
    Generates offload ME model from offload JSON file for a single map op.

    Args:
        json_offload: Dict with an "operations" list. Each entry provides
            "tensor_op_name" (a key of `op_to_model`) and "tensor_op_params"
            (a mapping from argument name to value).

    Raises:
        RuntimeError: If `json_offload` is None.
        KeyError: If an operation name is not in `op_to_model`, or a required
            argument is missing from its parameters.
    """

    def __init__(self, json_offload):
        super().__init__()
        self.me_ops = []

        if json_offload is None:
            raise RuntimeError("Offload hardware accelarator cannot be applied for this pipeline.")

        for op in json_offload["operations"]:
            name = op["tensor_op_name"]
            args = op["tensor_op_params"]
            op_model = op_to_model[name]
            # Collect constructor arguments in the order given by `args_names`;
            # ops without `args_names` are built with no arguments.
            arg_names = op_model.args_names
            op_model_inputs = [] if arg_names is None else [args[arg_key] for arg_key in arg_names]
            self.me_ops.append(op_model.func(*op_model_inputs))

        # Chain the per-op cells in the order they appear in the JSON.
        self.cell = nn.SequentialCell(self.me_ops)

    def construct(self, x):
        output = self.cell(x)
        return output
class GetOffloadModel(nn.Cell):
    """
    Generates offload ME model.

    Args:
        dataset_consumer: Object whose `GetOffload()` returns a JSON string. The
            decoded value is either None or a sequence of nodes carrying "op_type".
    """

    def __init__(self, dataset_consumer):
        super().__init__()
        self.transform_list = []
        offload_nodes = json.loads(dataset_consumer.GetOffload())
        if offload_nodes is not None:
            # Build one sub-model per node whose "op_type" is 'Map'.
            for node in offload_nodes:
                if node["op_type"] != 'Map':
                    continue
                self.transform_list.append(GetModelFromJson2Col(node))
            # The sub-models are applied in the reverse of their order in the JSON.
            self.transform_list.reverse()

    def construct(self, x):
        # Feed the input through every sub-model in turn.
        for stage in self.transform_list:
            x = stage(x)
        return x

View File

@ -26,7 +26,7 @@ from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \ validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \
check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str, check_dataset_num_shards_shard_id
from . import datasets from . import datasets
from . import samplers from . import samplers
@ -548,13 +548,7 @@ def check_generatordataset(method):
num_shards = param_dict.get("num_shards") num_shards = param_dict.get("num_shards")
shard_id = param_dict.get("shard_id") shard_id = param_dict.get("shard_id")
if (num_shards is None) != (shard_id is None): check_dataset_num_shards_shard_id(num_shards, shard_id)
# These two parameters appear together.
raise ValueError("num_shards and shard_id need to be passed in together.")
if num_shards is not None:
check_pos_int32(num_shards, "num_shards")
if shard_id >= num_shards:
raise ValueError("shard_id should be less than num_shards.")
sampler = param_dict.get("sampler") sampler = param_dict.get("sampler")
if sampler is not None: if sampler is not None:
@ -776,7 +770,7 @@ def check_map(method):
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
from mindspore.dataset.callback import DSCallback from mindspore.dataset.callback import DSCallback
[_, input_columns, output_columns, column_order, num_parallel_workers, python_multiprocessing, cache, [_, input_columns, output_columns, column_order, num_parallel_workers, python_multiprocessing, cache,
callbacks, max_rowsize], _ = \ callbacks, max_rowsize, offload], _ = \
parse_user_args(method, *args, **kwargs) parse_user_args(method, *args, **kwargs)
nreq_param_columns = ['input_columns', 'output_columns', 'column_order'] nreq_param_columns = ['input_columns', 'output_columns', 'column_order']
@ -788,6 +782,7 @@ def check_map(method):
type_check(python_multiprocessing, (bool,), "python_multiprocessing") type_check(python_multiprocessing, (bool,), "python_multiprocessing")
check_cache_option(cache) check_cache_option(cache)
type_check(max_rowsize, (int,), "max_rowsize") type_check(max_rowsize, (int,), "max_rowsize")
type_check(offload, (bool,), "offload")
if callbacks is not None: if callbacks is not None:
if isinstance(callbacks, (list, tuple)): if isinstance(callbacks, (list, tuple)):

View File

@ -118,6 +118,18 @@ def _generate_network_with_dataset(network, dataset_helper, queue_name):
return network return network
def _check_add_offload(dataset, dataset_helper, network):
    """
    Wrap `network` with the dataset's offload model when offload is enabled.

    Args:
        dataset: Object exposing `__transfer_dataset__`.
        dataset_helper: Object whose `types_shapes()` returns a (types, shapes) pair.
        network: Network to wrap.

    Returns:
        `network` unchanged when `check_map_offload` reports no offload; otherwise
        an `ApplyPreTransform` cell built from the offload model and `network`.

    Raises:
        RuntimeError: If offload is enabled and the dataset does not have exactly two columns.
    """
    from mindspore.dataset.engine import offload

    if not offload.check_map_offload(dataset.__transfer_dataset__):
        return network

    # A temporary solution to ensure there are two columns in dataset.
    column_types, _ = dataset_helper.types_shapes()
    if len(column_types) != 2:
        raise RuntimeError("Offload can currently only use datasets with two columns.")

    pre_transform = dataset.__transfer_dataset__.get_offload_model()
    return offload.ApplyPreTransform(pre_transform, network)
def connect_network_with_dataset(network, dataset_helper): def connect_network_with_dataset(network, dataset_helper):
""" """
Connect the `network` with dataset in `dataset_helper`. Connect the `network` with dataset in `dataset_helper`.
@ -153,7 +165,6 @@ def connect_network_with_dataset(network, dataset_helper):
>>> net = Net() >>> net = Net()
>>> net_with_get_next = connect_network_with_dataset(net, dataset_helper) >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
""" """
dataset_iter = dataset_helper.iter dataset_iter = dataset_helper.iter
dataset = dataset_iter.dataset dataset = dataset_iter.dataset
@ -191,6 +202,7 @@ def connect_network_with_dataset(network, dataset_helper):
not context.get_context("enable_ge") and \ not context.get_context("enable_ge") and \
context.get_context("device_target") in ("Ascend", "GPU"): context.get_context("device_target") in ("Ascend", "GPU"):
dataset.__me_inited__ = True dataset.__me_inited__ = True
network = _check_add_offload(dataset, dataset_helper, network)
network = _generate_network_with_dataset(network, dataset_helper, queue_name) network = _generate_network_with_dataset(network, dataset_helper, queue_name)
if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter): if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter):

View File

@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
# Relative path of the image-folder dataset read by test_offload.
DATA_DIR = "../data/dataset/testPK/data"
def test_offload():
    """
    Feature: test map offload flag.
    Description: Input is image dataset.
    Expectation: Output should be same with activated or deactivated offload.
    """
    # Pipeline whose HWC2CHW map is marked for offload.
    offloaded = (ds.ImageFolderDataset(DATA_DIR)
                 .map(operations=[C.Decode()], input_columns="image")
                 .map(operations=[C.HWC2CHW()], input_columns="image", offload=True)
                 .batch(8, drop_remainder=True))

    # Reference pipeline: the same ops with no offload argument.
    reference = (ds.ImageFolderDataset(DATA_DIR)
                 .map(operations=[C.Decode()], input_columns="image")
                 .map(operations=[C.HWC2CHW()], input_columns="image")
                 .batch(8, drop_remainder=True))

    offloaded_rows = offloaded.create_tuple_iterator(num_epochs=1, output_numpy=True)
    reference_rows = reference.create_tuple_iterator(num_epochs=1, output_numpy=True)

    # Compare the image column batch by batch; the second column is ignored.
    for (img_offloaded, _), (img_reference, _) in zip(offloaded_rows, reference_rows):
        np.testing.assert_array_equal(img_offloaded, img_reference)
# Run the offload test when this file is executed directly as a script.
if __name__ == "__main__":
    test_offload()