forked from mindspore-Ecosystem/mindspore
!25473 Model support dataset that shape is dynamic
Merge pull request !25473 from wangnan39/model_adapt_dynamic_shape_dataset
This commit is contained in:
commit
020a54398a
|
@ -97,7 +97,15 @@ bool Shape::operator==(const BaseShape &other) const {
|
|||
if (tid() != other.tid()) {
|
||||
return false;
|
||||
}
|
||||
return shape_ == static_cast<const Shape &>(other).shape_;
|
||||
Shape other_shape = static_cast<const Shape &>(other);
|
||||
bool shape_equal = shape_ == other_shape.shape_;
|
||||
|
||||
if (!IsDynamic() || !other_shape.IsDynamic()) {
|
||||
return shape_equal;
|
||||
}
|
||||
bool min_shape_equel = min_shape_ == other_shape.min_shape_;
|
||||
bool max_shape_equel = max_shape_ == other_shape.max_shape_;
|
||||
return shape_equal && min_shape_equel && max_shape_equel;
|
||||
}
|
||||
|
||||
const int64_t Shape::SHP_ANY;
|
||||
|
|
|
@ -38,19 +38,35 @@ def _send_data_no_flag(dataset, epoch_num):
|
|||
exec_dataset.send(epoch_num)
|
||||
|
||||
|
||||
def _dynamic_sink_scenario(dataset, dataset_iter):
|
||||
"""Special scenario with dynamic shape and sink_size=1."""
|
||||
flag = False
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
def _dynamic_sink_data(dataset, dataset_iter):
|
||||
"""Special scenario for dataset with sink_size=1."""
|
||||
if hasattr(dataset_iter, "sink_size") and \
|
||||
dataset_iter.sink_size == 1 and \
|
||||
dataset.get_dataset_size() != 1 and \
|
||||
hasattr(dataset_iter, "sink_count") and \
|
||||
dataset_iter.sink_count == 1 and \
|
||||
context.get_context("device_target") == "Ascend" and \
|
||||
context.get_context("mode") == context.GRAPH_MODE and \
|
||||
ms_role != "MS_WORKER":
|
||||
context.get_context("device_target") == "Ascend":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _dynamic_sink_exception_scenario(dataset_iter):
|
||||
"""The exception scenario for dynamic data is not applicable."""
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
_, dataset_shapes = dataset_iter.types_shapes()
|
||||
|
||||
if _has_dynamic_shape(dataset_shapes) or ms_role == "MS_WORKER" or \
|
||||
context.get_context("mode") != context.GRAPH_MODE:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _dynamic_sink_scenario(dataset, dataset_iter):
|
||||
"""Special scenario with dynamic shape and sink_size=1."""
|
||||
flag = False
|
||||
if _dynamic_sink_data(dataset, dataset_iter) and not _dynamic_sink_exception_scenario(dataset_iter):
|
||||
flag = True
|
||||
|
||||
return flag
|
||||
|
||||
|
||||
|
@ -86,7 +102,7 @@ def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queu
|
|||
return network
|
||||
|
||||
|
||||
def has_dynamic_shape(dataset_shapes):
|
||||
def _has_dynamic_shape(dataset_shapes):
|
||||
for shape in dataset_shapes:
|
||||
if -1 in shape:
|
||||
return True
|
||||
|
@ -95,7 +111,7 @@ def has_dynamic_shape(dataset_shapes):
|
|||
|
||||
def _generate_network_with_dataset(network, dataset_helper, queue_name):
|
||||
dataset_types, dataset_shapes = dataset_helper.types_shapes()
|
||||
(min_shapes, max_shapes) = (None, None) if not has_dynamic_shape(dataset_shapes) \
|
||||
(min_shapes, max_shapes) = (None, None) if not _has_dynamic_shape(dataset_shapes) \
|
||||
else dataset_helper.dynamic_min_max_shapes()
|
||||
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types,
|
||||
queue_name, min_shapes, max_shapes)
|
||||
|
@ -177,13 +193,7 @@ def connect_network_with_dataset(network, dataset_helper):
|
|||
dataset.__me_inited__ = True
|
||||
network = _generate_network_with_dataset(network, dataset_helper, queue_name)
|
||||
|
||||
if hasattr(dataset_iter, "sink_size") and \
|
||||
dataset_iter.sink_size == 1 and \
|
||||
dataset.get_dataset_size() != 1 and \
|
||||
hasattr(dataset_iter, "sink_count") and \
|
||||
dataset_iter.sink_count == 1 and \
|
||||
context.get_context("device_target") == "Ascend" and \
|
||||
context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter):
|
||||
dataset_helper.get_data_info()
|
||||
|
||||
return network
|
||||
|
|
Loading…
Reference in New Issue