!25473 Model support dataset that shape is dynamic

Merge pull request !25473 from wangnan39/model_adapt_dynamic_shape_dataset
This commit is contained in:
i-robot 2021-10-30 03:21:00 +00:00 committed by Gitee
commit 020a54398a
2 changed files with 35 additions and 17 deletions

View File

@ -97,7 +97,15 @@ bool Shape::operator==(const BaseShape &other) const {
if (tid() != other.tid()) { if (tid() != other.tid()) {
return false; return false;
} }
return shape_ == static_cast<const Shape &>(other).shape_; Shape other_shape = static_cast<const Shape &>(other);
bool shape_equal = shape_ == other_shape.shape_;
if (!IsDynamic() || !other_shape.IsDynamic()) {
return shape_equal;
}
bool min_shape_equel = min_shape_ == other_shape.min_shape_;
bool max_shape_equel = max_shape_ == other_shape.max_shape_;
return shape_equal && min_shape_equel && max_shape_equel;
} }
const int64_t Shape::SHP_ANY; const int64_t Shape::SHP_ANY;

View File

@ -38,19 +38,35 @@ def _send_data_no_flag(dataset, epoch_num):
exec_dataset.send(epoch_num) exec_dataset.send(epoch_num)
def _dynamic_sink_scenario(dataset, dataset_iter): def _dynamic_sink_data(dataset, dataset_iter):
"""Special scenario with dynamic shape and sink_size=1.""" """Special scenario for dataset with sink_size=1."""
flag = False
ms_role = os.getenv("MS_ROLE")
if hasattr(dataset_iter, "sink_size") and \ if hasattr(dataset_iter, "sink_size") and \
dataset_iter.sink_size == 1 and \ dataset_iter.sink_size == 1 and \
dataset.get_dataset_size() != 1 and \ dataset.get_dataset_size() != 1 and \
hasattr(dataset_iter, "sink_count") and \ hasattr(dataset_iter, "sink_count") and \
dataset_iter.sink_count == 1 and \ dataset_iter.sink_count == 1 and \
context.get_context("device_target") == "Ascend" and \ context.get_context("device_target") == "Ascend":
context.get_context("mode") == context.GRAPH_MODE and \ return True
ms_role != "MS_WORKER": return False
def _dynamic_sink_exception_scenario(dataset_iter):
"""The exception scenario for dynamic data is not applicable."""
ms_role = os.getenv("MS_ROLE")
_, dataset_shapes = dataset_iter.types_shapes()
if _has_dynamic_shape(dataset_shapes) or ms_role == "MS_WORKER" or \
context.get_context("mode") != context.GRAPH_MODE:
return True
return False
def _dynamic_sink_scenario(dataset, dataset_iter):
"""Special scenario with dynamic shape and sink_size=1."""
flag = False
if _dynamic_sink_data(dataset, dataset_iter) and not _dynamic_sink_exception_scenario(dataset_iter):
flag = True flag = True
return flag return flag
@ -86,7 +102,7 @@ def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queu
return network return network
def has_dynamic_shape(dataset_shapes): def _has_dynamic_shape(dataset_shapes):
for shape in dataset_shapes: for shape in dataset_shapes:
if -1 in shape: if -1 in shape:
return True return True
@ -95,7 +111,7 @@ def has_dynamic_shape(dataset_shapes):
def _generate_network_with_dataset(network, dataset_helper, queue_name): def _generate_network_with_dataset(network, dataset_helper, queue_name):
dataset_types, dataset_shapes = dataset_helper.types_shapes() dataset_types, dataset_shapes = dataset_helper.types_shapes()
(min_shapes, max_shapes) = (None, None) if not has_dynamic_shape(dataset_shapes) \ (min_shapes, max_shapes) = (None, None) if not _has_dynamic_shape(dataset_shapes) \
else dataset_helper.dynamic_min_max_shapes() else dataset_helper.dynamic_min_max_shapes()
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types,
queue_name, min_shapes, max_shapes) queue_name, min_shapes, max_shapes)
@ -177,13 +193,7 @@ def connect_network_with_dataset(network, dataset_helper):
dataset.__me_inited__ = True dataset.__me_inited__ = True
network = _generate_network_with_dataset(network, dataset_helper, queue_name) network = _generate_network_with_dataset(network, dataset_helper, queue_name)
if hasattr(dataset_iter, "sink_size") and \ if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter):
dataset_iter.sink_size == 1 and \
dataset.get_dataset_size() != 1 and \
hasattr(dataset_iter, "sink_count") and \
dataset_iter.sink_count == 1 and \
context.get_context("device_target") == "Ascend" and \
context.get_context("mode") == context.PYNATIVE_MODE:
dataset_helper.get_data_info() dataset_helper.get_data_info()
return network return network