!27081 [MD][Offload] Move offload checks and split to the C pipeline

Merge pull request !27081 from markuskunej/move_offload_check_to_c
i-robot 2021-12-07 15:55:16 +00:00 committed by Gitee
commit 75e35bded2
10 changed files with 142 additions and 166 deletions
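In short: map()'s offload argument becomes tri-state. The Python layer now forwards offload as given (None, False, or True), MapDataset converts it to a ManualOffloadMode enum (UNSPECIFIED, DISABLED, ENABLED), and the offload eligibility checks and map-splitting logic previously implemented in Python move into the C++ pre-pass (NodeOffloadPass), which also consults the auto_offload config. A minimal usage sketch of the resulting behavior (DATA_DIR is a hypothetical path; assumes a standard MindSpore install):

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C

DATA_DIR = "path/to/image_folder"  # hypothetical dataset location

ds.config.set_auto_offload(True)   # opt every map op in by default

# offload=None (UNSPECIFIED): follows the auto_offload config, so this map is offloaded.
# (Decode is not offload-supported, so the pass splits the map and offloads only HWC2CHW.)
data = ds.ImageFolderDataset(DATA_DIR)
data = data.map(operations=[C.Decode(), C.HWC2CHW()], input_columns="image")
data = data.batch(8, drop_remainder=True)

# offload=False (DISABLED): stays in the CPU pipeline even though auto_offload is on
data_cpu = ds.ImageFolderDataset(DATA_DIR)
data_cpu = data_cpu.map(operations=[C.Decode(), C.HWC2CHW()], input_columns="image", offload=False)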

View File

@@ -193,11 +193,12 @@ PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) {
.def(py::init([](std::shared_ptr<DatasetNode> self, py::list operations, py::list input_columns,
py::list output_columns, py::list project_columns,
std::vector<std::shared_ptr<PyDSCallback>> py_callbacks, int64_t max_rowsize,
bool offload) {
int offload) {
auto map = std::make_shared<MapNode>(
self, std::move(toTensorOperations(operations)), toStringVector(input_columns),
toStringVector(output_columns), toStringVector(project_columns), nullptr,
std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end()), offload);
std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end()),
static_cast<ManualOffloadMode>(offload));
THROW_IF_ERROR(map->ValidateParams());
return map;
}));

View File

@@ -35,7 +35,7 @@ namespace dataset {
MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache,
std::vector<std::shared_ptr<DSCallback>> callbacks, bool offload)
std::vector<std::shared_ptr<DSCallback>> callbacks, ManualOffloadMode offload)
: operations_(operations),
input_columns_(input_columns),
output_columns_(output_columns),
@@ -150,7 +150,7 @@ void MapNode::setOperations(const std::vector<std::shared_ptr<TensorOperation>>
}
std::vector<std::shared_ptr<TensorOperation>> MapNode::operations() { return operations_; }
void MapNode::SetOffload(bool offload) { offload_ = offload; }
void MapNode::SetOffload(ManualOffloadMode offload) { offload_ = offload; }
Status MapNode::to_json(nlohmann::json *out_json) {
RETURN_UNEXPECTED_IF_NULL(out_json);

View File

@@ -26,13 +26,17 @@
namespace mindspore {
namespace dataset {
/// \brief Enum for the manual offload state
enum class ManualOffloadMode { UNSPECIFIED = 0, DISABLED, ENABLED };
class MapNode : public DatasetNode {
public:
/// \brief Constructor
MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {},
const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr,
std::vector<std::shared_ptr<DSCallback>> callbacks = {}, bool offload = false);
std::vector<std::shared_ptr<DSCallback>> callbacks = {},
ManualOffloadMode offload = ManualOffloadMode::UNSPECIFIED);
/// \brief Destructor
~MapNode() = default;
@@ -87,10 +91,10 @@ class MapNode : public DatasetNode {
const std::vector<std::string> &OutputColumns() const { return output_columns_; }
const std::vector<std::string> &ProjectColumns() const { return project_columns_; }
const std::vector<std::shared_ptr<DSCallback>> &Callbacks() const { return callbacks_; }
bool GetOffload() const { return offload_; }
ManualOffloadMode GetOffload() const { return offload_; }
/// \brief setter to set offload flag of node
void SetOffload(bool offload);
void SetOffload(ManualOffloadMode offload);
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
@@ -123,8 +127,8 @@ class MapNode : public DatasetNode {
std::vector<std::string> project_columns_;
std::vector<std::shared_ptr<DSCallback>> callbacks_;
/// \brief Flag to indicate whether offload is set for the Map node.
bool offload_;
/// \brief ManualOffloadMode to indicate manual_offload status
ManualOffloadMode offload_;
};
} // namespace dataset

View File

@@ -13,29 +13,100 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "minddata/dataset/engine/opt/pre/node_offload_pass.h"
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
NodeOffloadPass::OffloadNodes::OffloadNodes() : prev_map_offloaded_(true) {}
NodeOffloadPass::OffloadNodes::OffloadNodes()
: prev_map_offloaded_(true), auto_offload_(GlobalContext::config_manager()->get_auto_offload()) {}
// Perform MapNode offload check.
Status NodeOffloadPass::OffloadNodes::Visit(std::shared_ptr<MapNode> node, bool *const modified) {
*modified = false;
// Check if this node is set to offload and add to nodes_to_offload_.
if (node->GetOffload() == true) {
ManualOffloadMode manual_offload = node->GetOffload();
bool offload_successful = false;
// Offload if the node is manually set to ENABLED, or if auto_offload is enabled and the node is not manually set to DISABLED.
if ((manual_offload == ManualOffloadMode::ENABLED) ||
((auto_offload_ == true) && (manual_offload != ManualOffloadMode::DISABLED))) {
bool offload_supported = true;
MS_LOG(INFO) << "Pre pass: node offload of map class is true.";
if (prev_map_offloaded_) {
nodes_to_offload_.push_back(std::static_pointer_cast<DatasetNode>(node));
} else {
MS_LOG(WARNING) << "Invalid use of offload in map, ignoring offload flag. Ops will be run in CPU pipeline";
node->SetOffload(false);
*modified = true;
// Currently offload not supported for different output_columns.
if (node->InputColumns() != node->OutputColumns()) {
MS_LOG(WARNING) << "Cannot offload map operation with output_columns != input_columns. Turning offload off.";
offload_supported = false;
}
} else {
// Check if map operation is at the end of the pipeline.
if (!prev_map_offloaded_) {
MS_LOG(WARNING) << "Map operation is not at the end of the pipeline (there exists a non-offloaded map after this "
"one). Turning offload off.";
offload_supported = false;
}
if (offload_supported) {
std::vector<std::string> invalid_ops;
std::vector<std::shared_ptr<TensorOperation>> temp_operations = node->operations();
bool all_valid_ops = true;
int last_invalid_op_pos = 1;
int pos = 1;
// Check individual operations to see if they are supported by offload.
for (auto operation : temp_operations) {
std::string op_name = operation->Name();
if (supported_ops_.find(op_name) == supported_ops_.end()) {
last_invalid_op_pos = pos;
invalid_ops.push_back(op_name);
all_valid_ops = false;
}
pos++;
}
if (all_valid_ops) {
// All operations can be offloaded.
nodes_to_offload_.push_back(std::static_pointer_cast<DatasetNode>(node));
offload_successful = true;
} else {
// Some operation(s) cannot be offloaded.
MS_LOG(WARNING)
<< "In Map Node, offload is set to True, but offload is not supported by the following operation(s): "
<< invalid_ops;
// See if the operations can be split into two Map Nodes
if (last_invalid_op_pos != temp_operations.size()) {
MS_LOG(WARNING) << "Map operation will be split after " << invalid_ops.back()
<< ", with the second map operation being offloaded.";
std::vector<std::shared_ptr<TensorOperation>> non_offload_ops(temp_operations.begin(),
temp_operations.begin() + last_invalid_op_pos);
std::vector<std::shared_ptr<TensorOperation>> offload_ops(temp_operations.begin() + last_invalid_op_pos,
temp_operations.end());
// First set operations to offload_ops to prepare for copy
node->setOperations(offload_ops);
// Copy node (returns a copy of the node, but without children)
std::shared_ptr<DatasetNode> offload_node = node->Copy();
// Set the number of parallel workers of the new node to be the same as current one.
offload_node->SetNumWorkers(node->NumWorkers());
node->setOperations(non_offload_ops);
// Insert the split offload map node above the original map node in the ir tree.
node->InsertAbove(offload_node);
// Add the offload map node to nodes_to_offload
nodes_to_offload_.push_back(offload_node);
} else {
MS_LOG(WARNING) << "No operations can be offloaded through splitting.";
}
}
}
}
if (!offload_successful) {
// Offload of the original node without modification did not take place.
// Since map nodes are visited in reverse order, no other map ops can be offloaded after this.
prev_map_offloaded_ = false;
}
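The split rule implemented above, in brief: scan the map's operations in order and record the position of the last one not in supported_ops_; everything up to and including that operation stays in the original (CPU) map node, while the remaining tail of supported operations is copied into a new map node inserted above it in the IR tree and marked for offload. A self-contained sketch of the split-point computation (illustrative Python, not the MindSpore API):

SUPPORTED_OPS = {"HwcToChw", "Normalize", "RandomColorAdjust", "RandomHorizontalFlip",
                 "RandomSharpness", "RandomVerticalFlip", "Rescale"}

def split_for_offload(op_names):
    """Return (cpu_ops, offload_ops), split after the last unsupported op."""
    last_invalid_pos = 0                   # mirrors last_invalid_op_pos in the pass
    for pos, name in enumerate(op_names, start=1):
        if name not in SUPPORTED_OPS:
            last_invalid_pos = pos
    if last_invalid_pos == len(op_names):  # last op unsupported: nothing to offload
        return op_names, []
    return op_names[:last_invalid_pos], op_names[last_invalid_pos:]

# e.g. ["Decode", "Normalize", "HwcToChw"] -> (["Decode"], ["Normalize", "HwcToChw"])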

View File

@@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_NODE_OFFLOAD_PASS_H_
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
@@ -49,8 +51,16 @@ class NodeOffloadPass : public IRTreePass {
std::vector<std::shared_ptr<DatasetNode>> nodes_to_offload() { return nodes_to_offload_; }
private:
/// \brief Vector of nodes to offload
std::vector<std::shared_ptr<DatasetNode>> nodes_to_offload_;
/// \brief Set of operation names supported by offload
const std::set<std::string> supported_ops_{
"HwcToChw", "Normalize", "RandomColorAdjust", "RandomHorizontalFlip", "RandomSharpness",
"RandomVerticalFlip", "Rescale"};
/// \brief bool indicating whether all map ops visited so far (those later in the pipeline) were offloaded
bool prev_map_offloaded_;
/// \brief bool indicating whether the auto_offload config option is enabled
bool auto_offload_;
};
public:

View File

@@ -52,7 +52,7 @@ from mindspore.common import Tensor
from mindspore import log as logger
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from mindspore.parallel._utils import _get_device_num
from mindspore.dataset.engine.offload import GetOffloadModel, op_to_model
from mindspore.dataset.engine.offload import GetOffloadModel
import mindspore.dataset.transforms.py_transforms as py_transforms
@@ -73,7 +73,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_yes_no_dataset, check_speech_commands_dataset, check_tedlium_dataset, check_svhn_dataset, \
check_stl10_dataset
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size, get_auto_offload
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
from ..core.validator_helpers import replace_none
from ..core.py_util_helpers import ExceptionHandler
@@ -87,6 +87,17 @@ except ModuleNotFoundError:
if platform.system().lower() == "darwin":
multiprocessing.set_start_method("fork")
class ManualOffloadMode(Enum):
UNSPECIFIED = 0
DISABLED = 1
ENABLED = 2
OffloadToManualOffloadMode = {
None: ManualOffloadMode.UNSPECIFIED,
False: ManualOffloadMode.DISABLED,
True: ManualOffloadMode.ENABLED
}
class Shuffle(str, Enum):
GLOBAL: str = "global"
FILES: str = "files"
@@ -98,95 +109,6 @@ ShuffleToShuffleMode = {Shuffle.FILES: cde.ShuffleMode.FILES,
Shuffle.INFILE: cde.ShuffleMode.INFILE}
def get_offloadable_ops(operations):
"""
Check if operations are supported by offload hardware accelerator.
Args:
operations: list of operations.
Returns:
Dictionary with boolean key for each operation for offload support.
"""
is_offloadable = {}
if not isinstance(operations, list):
operations = [operations]
for op in operations:
name = op.__class__.__name__
if name in op_to_model:
is_offloadable[name] = True
else:
is_offloadable[name] = False
return is_offloadable
def check_offload_map(operations, output_columns):
"""
Check if operations are supported by the offload hardware accelerator. If not, see whether the list of
operations can be split in two: operations without offload support followed by offload-supported operations.
Args:
operations: list of operations.
output_columns: list of names assigned to the columns outputted by the last operation.
Returns:
bool, indicates whether to use the offload hardware accelerator.
bool, indicates whether list of map operations can be split.
list, first group of non-offload supported operations.
list, second group of offload supported operations.
"""
offloadable_ops = get_offloadable_ops(operations)
offload = True
can_split = False
offload_ops = []
non_offload_ops = []
invalid_ops = []
for op in offloadable_ops:
if offloadable_ops[op] is not True:
offload = False
invalid_ops.append(op)
if not offload:
logger.warning(("In map(), offload is set to True, but offload is not supported for the following "
"operation(s): {}").format(*invalid_ops))
if output_columns:
# Cannot split (currently), unsure which side of operations would alter the output columns
logger.warning("Since output_columns is specified, the list of operations cannot be split. "
"Unsure which operation(s) alter the columns. Setting offload to False.")
else:
# See if the map operator can be split and then offloaded
size = len(offloadable_ops)
idx = size
split_idx = size
op_names = list(offloadable_ops.keys())
for op_name in reversed(op_names):
if not offloadable_ops[op_name]:
# From reverse order, this op cannot be offloaded, therefore split here.
split_idx = idx
break
idx = idx - 1
if split_idx == size:
# The last op in the list cannot be offloaded, therefore nothing can be offloaded.
# Nothing to split.
logger.warning(("The last operation, {}, is not supported by offload, setting offload"
" to False").format(op_names[split_idx - 1]))
elif split_idx != 0:
# There is at least one offloadable op at the end of the list.
# Split map() after the last non-offloadable op and only offload the second list of operations.
can_split = True
non_offload_ops = operations[:split_idx]
offload_ops = operations[split_idx:]
logger.warning(("The list of operations in map() can be split into two: {}, {}\n"
"The second list of operations will be run with offload=True"
).format(op_names[:split_idx], op_names[split_idx:]))
return offload, can_split, non_offload_ops, offload_ops
def shuffle_to_shuffle_mode(shuffle):
"""
Shuffle Enum to Shuffle Mode
@@ -880,28 +802,8 @@ class Dataset:
... output_columns=["mod2", "mod3", "mod5", "mod7"],
... column_order=["mod7", "mod3", "col2"])
"""
can_split = False
non_offload_ops = []
offload_ops = []
if offload is not None:
offload_flag = offload
else:
offload_flag = get_auto_offload()
if offload_flag:
offload_flag, can_split, non_offload_ops, offload_ops = check_offload_map(operations, output_columns)
if can_split:
non_offload_map_ds = MapDataset(self, non_offload_ops, input_columns, output_columns, column_order,
num_parallel_workers, python_multiprocessing, cache, callbacks,
max_rowsize, offload=False)
return MapDataset(non_offload_map_ds, offload_ops, input_columns, output_columns, column_order,
num_parallel_workers, python_multiprocessing, cache, callbacks, max_rowsize,
offload=True)
return MapDataset(self, operations, input_columns, output_columns, column_order, num_parallel_workers,
python_multiprocessing, cache, callbacks, max_rowsize, offload_flag)
python_multiprocessing, cache, callbacks, max_rowsize, offload)
@check_filter
def filter(self, predicate, input_columns=None, num_parallel_workers=None):
@@ -2881,7 +2783,7 @@ class MapDataset(Dataset):
callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None)
max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
data between processes. This is only used if python_multiprocessing is set to True (default=16).
offload (bool, optional): Flag to indicate whether offload is used (Default=False).
offload (bool, optional): Flag to indicate whether offload is used (Default=None).
Raises:
ValueError: If len(input_columns) != len(output_columns) and column_order is not specified.
@@ -2889,7 +2791,7 @@ class MapDataset(Dataset):
def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, column_order=None,
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16,
offload=False):
offload=None):
super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache)
self.operations = to_list(operations)
self.operations = py_transforms.Compose.reduce(self.operations)
@@ -2915,7 +2817,8 @@
self.callbacks = to_list(callbacks)
self.max_rowsize = max_rowsize
self.offload = offload
self.offload = OffloadToManualOffloadMode[offload]
def parse(self, children=None):
operations = []
@@ -2927,7 +2830,7 @@
callbacks = [cb.create_runtime_obj() for cb in self.callbacks]
return cde.MapNode(children[0], operations, self.input_columns, self.output_columns, self.column_order,
callbacks, self.max_rowsize, self.offload)
callbacks, self.max_rowsize, self.offload.value)
def __deepcopy__(self, memodict):
return self.__safe_deepcopy__(memodict, exclude=("operations", "callbacks", "__transfer_dataset__"))
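Net effect on this file: map() no longer inspects or splits operations itself; it passes offload straight through, and MapDataset stores OffloadToManualOffloadMode[offload], sending the enum's integer value across the binding where it is cast back to the C++ ManualOffloadMode. A standalone sketch of that mapping (re-stated here for illustration, not imported from MindSpore):

from enum import Enum

class ManualOffloadMode(Enum):  # mirrors the Python and C++ enums in this commit
    UNSPECIFIED = 0
    DISABLED = 1
    ENABLED = 2

OffloadToManualOffloadMode = {None: ManualOffloadMode.UNSPECIFIED,
                              False: ManualOffloadMode.DISABLED,
                              True: ManualOffloadMode.ENABLED}

mode = OffloadToManualOffloadMode[None]  # user passed offload=None
assert mode is ManualOffloadMode.UNSPECIFIED
assert mode.value == 0  # this int crosses pybind and is cast to the C++ enum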

View File

@@ -88,8 +88,13 @@ class Iterator:
self.__index = 0
self.offload_model = None
if offload.check_map_offload(self.__ori_dataset):
self.offload_model = offload.GetOffloadModel(consumer)
offload_model = offload.GetOffloadModel(consumer)
# See if GetOffloadModel identified any operations set to be offloaded.
if offload_model.transform_list != []:
offload.check_concat_zip_dataset(self.__ori_dataset)
self.offload_model = offload_model
ITERATORS_LIST.append(weakref.ref(self))
_unset_iterator_cleanup()

View File

@@ -30,34 +30,11 @@ def check_concat_zip_dataset(dataset):
"""
while dataset:
if len(dataset.children) > 1:
return True
raise RuntimeError("Offload module currently does not support concatenated or zipped datasets.")
if dataset.children:
dataset = dataset.children[0]
continue
dataset = dataset.children
return False
def check_map_offload(dataset):
"""
Check if offload flag is set in data pipeline map ops.
"""
offload_check = False
concat_zip_check = check_concat_zip_dataset(dataset)
while dataset:
if hasattr(dataset, 'offload'):
if dataset.offload is True:
offload_check = True
break
if dataset.children:
dataset = dataset.children[0]
else:
dataset = []
if offload_check and concat_zip_check:
raise RuntimeError("Offload module currently does not support concatenated or zipped datasets.")
return offload_check
def apply_offload_iterators(data, offload_model):

View File

@@ -118,13 +118,16 @@ def _generate_network_with_dataset(network, dataset_helper, queue_name):
def _check_add_offload(dataset, dataset_helper, network):
"""Check if any map operations were removed to be offloaded and apply the transforms if so."""
from mindspore.dataset.engine import offload
if offload.check_map_offload(dataset.__transfer_dataset__):
offload_model = dataset.__transfer_dataset__.get_offload_model()
# See if the offload pass identified any operations to be offloaded
if offload_model.transform_list != []:
offload.check_concat_zip_dataset(dataset.__transfer_dataset__)
# A temporary solution to ensure there are two columns in dataset.
dataset_types, _ = dataset_helper.types_shapes()
if len(dataset_types) != 2:
raise RuntimeError("Offload can currently only use datasets with two columns.")
offload_model = dataset.__transfer_dataset__.get_offload_model()
network = offload.ApplyPreTransform(offload_model, network)
return network

View File

@@ -53,13 +53,15 @@ def test_auto_offload():
"""
trans = [C.Decode(), C.HWC2CHW()]
# Dataset with config.auto_offload not activated
# Enable automatic offload
ds.config.set_auto_offload(True)
# Dataset with offload deactivated
dataset_auto_disabled = ds.ImageFolderDataset(DATA_DIR)
dataset_auto_disabled = dataset_auto_disabled.map(operations=trans, input_columns="image")
dataset_auto_disabled = dataset_auto_disabled.map(operations=trans, input_columns="image", offload=False)
dataset_auto_disabled = dataset_auto_disabled.batch(8, drop_remainder=True)
# Dataset with config.auto_offload activated
ds.config.set_auto_offload(True)
dataset_auto_enabled = ds.ImageFolderDataset(DATA_DIR)
dataset_auto_enabled = dataset_auto_enabled.map(operations=trans, input_columns="image")
dataset_auto_enabled = dataset_auto_enabled.batch(8, drop_remainder=True)