forked from mindspore-Ecosystem/mindspore
!25473 Model support dataset that shape is dynamic
Merge pull request !25473 from wangnan39/model_adapt_dynamic_shape_dataset
commit 020a54398a
@@ -97,7 +97,15 @@ bool Shape::operator==(const BaseShape &other) const {
   if (tid() != other.tid()) {
     return false;
   }
-  return shape_ == static_cast<const Shape &>(other).shape_;
+  Shape other_shape = static_cast<const Shape &>(other);
+  bool shape_equal = shape_ == other_shape.shape_;
+  if (!IsDynamic() || !other_shape.IsDynamic()) {
+    return shape_equal;
+  }
+  bool min_shape_equel = min_shape_ == other_shape.min_shape_;
+  bool max_shape_equel = max_shape_ == other_shape.max_shape_;
+  return shape_equal && min_shape_equel && max_shape_equel;
 }

 const int64_t Shape::SHP_ANY;
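The C++ hunk above changes Shape equality: shapes still compare by their dims, but when both operands are dynamic their min_shape_/max_shape_ bounds must match as well. A minimal Python sketch of the same rule, assuming plain dicts in place of the Shape class and -1 as the dynamic marker (SHP_ANY on the C++ side):

# Sketch only: mirrors the equality rule from the hunk above using dicts.
# A "shape" here is {"shape": [...], "min_shape": [...], "max_shape": [...]};
# -1 marks a dynamic dimension.
def is_dynamic(s):
    return -1 in s["shape"]

def shapes_equal(a, b):
    shape_equal = a["shape"] == b["shape"]
    # If either side is fully static, the plain dims alone decide equality.
    if not is_dynamic(a) or not is_dynamic(b):
        return shape_equal
    # Both dynamic: the min/max bounds must match too.
    return shape_equal and a["min_shape"] == b["min_shape"] and a["max_shape"] == b["max_shape"]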
@@ -38,19 +38,35 @@ def _send_data_no_flag(dataset, epoch_num):
     exec_dataset.send(epoch_num)


-def _dynamic_sink_scenario(dataset, dataset_iter):
-    """Special scenario with dynamic shape and sink_size=1."""
-    flag = False
-    ms_role = os.getenv("MS_ROLE")
+def _dynamic_sink_data(dataset, dataset_iter):
+    """Special scenario for dataset with sink_size=1."""
     if hasattr(dataset_iter, "sink_size") and \
        dataset_iter.sink_size == 1 and \
        dataset.get_dataset_size() != 1 and \
        hasattr(dataset_iter, "sink_count") and \
        dataset_iter.sink_count == 1 and \
-       context.get_context("device_target") == "Ascend" and \
-       context.get_context("mode") == context.GRAPH_MODE and \
-       ms_role != "MS_WORKER":
+       context.get_context("device_target") == "Ascend":
+        return True
+    return False
+
+
+def _dynamic_sink_exception_scenario(dataset_iter):
+    """The exception scenario for dynamic data is not applicable."""
+    ms_role = os.getenv("MS_ROLE")
+    _, dataset_shapes = dataset_iter.types_shapes()
+
+    if _has_dynamic_shape(dataset_shapes) or ms_role == "MS_WORKER" or \
+       context.get_context("mode") != context.GRAPH_MODE:
+        return True
+    return False
+
+
+def _dynamic_sink_scenario(dataset, dataset_iter):
+    """Special scenario with dynamic shape and sink_size=1."""
+    flag = False
+    if _dynamic_sink_data(dataset, dataset_iter) and not _dynamic_sink_exception_scenario(dataset_iter):
         flag = True

     return flag
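The refactor above splits the old all-in-one check into two predicates: _dynamic_sink_data() recognizes the Ascend sink_size=1 / sink_count=1 layout, and _dynamic_sink_exception_scenario() lists the cases where the dynamic-shape sink path does not apply (the pipeline already reports dynamic shapes, the MS_WORKER role, or non-graph mode). The PR combines them in two places; a small sketch of that routing, with plain booleans standing in for the real checks:

# Sketch: how the two predicates from the diff are combined by this PR.
# The two boolean arguments stand in for _dynamic_sink_data(...) and
# _dynamic_sink_exception_scenario(...).
def sink_decisions(is_sink_data, is_exception):
    # _dynamic_sink_scenario() above is True only when the sink_size=1
    # layout applies and no exception case is hit.
    dynamic_sink_scenario = is_sink_data and not is_exception
    # connect_network_with_dataset() (last hunk below) calls
    # dataset_helper.get_data_info() when the layout applies but an
    # exception case is hit.
    call_get_data_info = is_sink_data and is_exception
    return dynamic_sink_scenario, call_get_data_info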
@@ -86,7 +102,7 @@ def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queu
         return network


-def has_dynamic_shape(dataset_shapes):
+def _has_dynamic_shape(dataset_shapes):
     for shape in dataset_shapes:
         if -1 in shape:
             return True
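The rename above makes the dynamic-shape probe a private helper; it simply scans the dataset shapes for a -1 dimension. A quick illustration with made-up shapes (the trailing return False is the obvious completion of the truncated hunk):

# Illustration with made-up shapes; -1 marks a dynamic dimension.
def _has_dynamic_shape(dataset_shapes):
    for shape in dataset_shapes:
        if -1 in shape:
            return True
    return False

print(_has_dynamic_shape([(32, 3, 224, 224)]))   # False: fully static
print(_has_dynamic_shape([(32, -1), (32, 1)]))   # True: second dim of the first shape is dynamic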
@@ -95,7 +111,7 @@ def has_dynamic_shape(dataset_shapes):


 def _generate_network_with_dataset(network, dataset_helper, queue_name):
     dataset_types, dataset_shapes = dataset_helper.types_shapes()
-    (min_shapes, max_shapes) = (None, None) if not has_dynamic_shape(dataset_shapes) \
+    (min_shapes, max_shapes) = (None, None) if not _has_dynamic_shape(dataset_shapes) \
         else dataset_helper.dynamic_min_max_shapes()
     network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types,
                                               queue_name, min_shapes, max_shapes)
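The only change in _generate_network_with_dataset is the call to the renamed helper: min/max shapes are fetched from the dataset helper only when at least one dataset shape is dynamic, otherwise they stay None. A hedged sketch of that selection, with a hypothetical stand-in for the real dataset helper object:

# _FakeHelper is a hypothetical stand-in; only the two methods used by the
# hunk above are modeled.
class _FakeHelper:
    def __init__(self, shapes, min_shapes=None, max_shapes=None):
        self._shapes, self._min, self._max = shapes, min_shapes, max_shapes

    def types_shapes(self):
        return ["float32"] * len(self._shapes), self._shapes

    def dynamic_min_max_shapes(self):
        return self._min, self._max


def _pick_min_max(helper):
    # Mirrors the hunk above: static pipelines get (None, None); dynamic
    # pipelines ask the helper for their per-dimension bounds.
    _, dataset_shapes = helper.types_shapes()
    has_dynamic = any(-1 in shape for shape in dataset_shapes)
    return helper.dynamic_min_max_shapes() if has_dynamic else (None, None)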
@@ -177,13 +193,7 @@ def connect_network_with_dataset(network, dataset_helper):
     dataset.__me_inited__ = True
     network = _generate_network_with_dataset(network, dataset_helper, queue_name)

-    if hasattr(dataset_iter, "sink_size") and \
-       dataset_iter.sink_size == 1 and \
-       dataset.get_dataset_size() != 1 and \
-       hasattr(dataset_iter, "sink_count") and \
-       dataset_iter.sink_count == 1 and \
-       context.get_context("device_target") == "Ascend" and \
-       context.get_context("mode") == context.PYNATIVE_MODE:
+    if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter):
         dataset_helper.get_data_info()

     return network